stable_pretraining package#
Subpackages#
- stable_pretraining.backbone package
- Submodules
- stable_pretraining.backbone.aggregator module
- stable_pretraining.backbone.convmixer module
- stable_pretraining.backbone.mae module
- stable_pretraining.backbone.mlp module
- stable_pretraining.backbone.probe module
- stable_pretraining.backbone.resnet9 module
- stable_pretraining.backbone.utils module
- Module contents
- stable_pretraining.callbacks package
- Submodules
- stable_pretraining.callbacks.checkpoint_sklearn module
- stable_pretraining.callbacks.cleanup module
- stable_pretraining.callbacks.clip_zero_shot module
- stable_pretraining.callbacks.cpu_offload module
- stable_pretraining.callbacks.earlystop module
- stable_pretraining.callbacks.embedding_cache module
- stable_pretraining.callbacks.env_info module
- stable_pretraining.callbacks.factories module
- stable_pretraining.callbacks.image_retrieval module
- stable_pretraining.callbacks.knn module
- stable_pretraining.callbacks.latent_viz module
- stable_pretraining.callbacks.lidar module
- stable_pretraining.callbacks.probe module
- stable_pretraining.callbacks.queue module
- stable_pretraining.callbacks.rankme module
- stable_pretraining.callbacks.teacher_student module
- stable_pretraining.callbacks.trainer_info module
- stable_pretraining.callbacks.utils module
- stable_pretraining.callbacks.wd_schedule module
- stable_pretraining.callbacks.writer module
- Module contents
- stable_pretraining.data package
- Submodules
- stable_pretraining.data.collate module
- stable_pretraining.data.dataset_stats module
- stable_pretraining.data.datasets module
- stable_pretraining.data.download module
- stable_pretraining.data.masking module
- stable_pretraining.data.module module
- stable_pretraining.data.sampler module
- stable_pretraining.data.synthetic_data module
- stable_pretraining.data.transforms module
AdditiveGaussian, CenterCrop, ColorJitter, Compose, Conditional, ContextTargetsMultiBlockMask, ControlledTransform, GaussianBlur, Lambda, MultiViewTransform, PILGaussianBlur, PatchMasking, RGB, RandomChannelPermutation, RandomContiguousTemporalSampler, RandomCrop, RandomGrayscale, RandomHorizontalFlip, RandomMask, RandomResizedCrop, RandomRotation, RandomSolarize, Resize, RoundRobinMultiViewTransform, RoutingTransform, ToImage, Transform, UniformTemporalSubsample, WrapTorchTransform, random_seed(), set_seed(), to_image()
- stable_pretraining.data.utils module
- Module contents
Categorical, Collator, DataModule, Dataset, ExponentialMixtureNoiseModel, ExponentialNormalNoiseModel, FromTorchDataset, GMM, HFDataset, MinariEpisodeDataset, MinariStepsDataset, RandomBatchSampler, RepeatedRandomSampler, Subset, SupervisedBatchSampler, bulk_download(), download(), fold_views(), generate_perlin_noise_2d(), perlin_noise_3d(), random_split(), swiss_roll()
- stable_pretraining.losses package
- Submodules
- stable_pretraining.losses.dino module
- stable_pretraining.losses.joint_embedding module
- stable_pretraining.losses.multimodal module
- stable_pretraining.losses.reconstruction module
- stable_pretraining.losses.utils module
- Module contents
- stable_pretraining.optim package
- stable_pretraining.utils package
- Submodules
- stable_pretraining.utils.batch_utils module
- stable_pretraining.utils.config module
- stable_pretraining.utils.data_generation module
- stable_pretraining.utils.distance_metrics module
- stable_pretraining.utils.distributed module
- stable_pretraining.utils.error_handling module
- stable_pretraining.utils.gdrive_utils module
- stable_pretraining.utils.inspection_utils module
- stable_pretraining.utils.lightning_patch module
- stable_pretraining.utils.log_reader module
- stable_pretraining.utils.nn_modules module
- stable_pretraining.utils.read_csv_logger module
- stable_pretraining.utils.timm_to_hf_hub module
- stable_pretraining.utils.visualization module
- Module contents
BatchNorm1dNoBias, CSVLogAutoSummarizer, EMA, FullGatherLayer, GDriveUploader, ImageToVideoEncoder, L2Norm, Normalize, OrderedQueue, UnsortedQueue, adapt_resnet_for_lowres(), all_gather(), all_reduce(), broadcast_param_to_list(), compute_pairwise_distances(), compute_pairwise_distances_chunked(), detach_tensors(), dict_values(), execute_from_config(), find_module(), format_df_to_latex(), generate_dae_samples(), generate_dm_samples(), generate_ssl_samples(), generate_sup_samples(), get_data_from_batch_or_outputs(), get_required_fn_parameters(), is_dist_avail_and_initialized(), load_hparams_from_ckpt(), replace_module(), rgetattr(), rsetattr(), with_hf_retry_ratelimit()
Submodules#
stable_pretraining.cli module#
Command-line interface for Stable SSL training.
- stable_pretraining.cli.dump_csv_logs(dir: str, output_name: str, agg: str)[source]#
Compress CSV logs to the smallest possible format with aggregation.
- stable_pretraining.cli.run(config: str, overrides: List[str] | None)[source]#
Execute experiment with the specified config.
Examples
spt run config.yaml
spt run config.yaml -m
spt run config.yaml trainer.max_epochs=100
stable_pretraining.config module#
Configuration classes specifying default parameters for stable-SSL.
- stable_pretraining.config.collapse_nested_dict(cfg: dict | object, level_separator: str = '.', _base_name: str = None, _flat_cfg: dict = None) dict[source]#
Parse a Hydra config and make it readable for wandb (flatten).
- Parameters:
cfg (Union[dict, object]) – The original (Hydra) nested dict.
level_separator (str, optional) – The string to separate level names. Defaults to “.”.
_base_name (str, optional) – The parent string, used for recursion only, users should ignore. Defaults to None.
_flat_cfg (dict, optional) – The flattened config, used for recursion only, users should ignore. Defaults to None.
- Returns:
Flat config.
- Return type:
dict
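Example
A minimal sketch of the flattening behavior; the nested config below is illustrative, not taken from the library:
from stable_pretraining.config import collapse_nested_dict

cfg = {"trainer": {"max_epochs": 100, "devices": 1}, "seed": 42}
flat = collapse_nested_dict(cfg)
# expected result with the default level_separator=".":
# {"trainer.max_epochs": 100, "trainer.devices": 1, "seed": 42}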
- stable_pretraining.config.instantiate_from_config(cfg: dict | DictConfig) Any[source]#
Main entry point for config-based training.
This function handles the complete instantiation of a training setup from config:
- Recursively instantiates all components
- Creates a Manager if trainer/module/data are present
- Returns the appropriate object based on the config structure
- Parameters:
cfg – Complete configuration dictionary or DictConfig
- Returns:
Manager instance if config contains trainer/module/data, otherwise returns instantiated config dict
- stable_pretraining.config.recursive_instantiate(cfg: dict | DictConfig, parent_objects: dict = None) dict[source]#
Recursively instantiate all components in config with dependency resolution.
- Parameters:
cfg – Configuration dictionary or DictConfig with _target_ fields
parent_objects – Optional dict of already instantiated objects for dependencies
- Returns:
Dictionary of instantiated components
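Example
A hedged sketch of the intended usage; the config keys and the torchvision target below are illustrative only, and the pass-through of plain values is an assumption:
from stable_pretraining.config import recursive_instantiate

cfg = {
    "backbone": {"_target_": "torchvision.models.resnet18", "num_classes": 10},
    "seed": 42,  # entries without _target_ are assumed to be returned as-is
}
components = recursive_instantiate(cfg)
# components["backbone"] is now an instantiated torch.nn.Module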
stable_pretraining.forward module#
Forward functions for self-supervised learning methods.
This module provides pre-defined forward functions for various SSL methods that can be used with the Module class. These functions define the training logic for each method and can be specified in YAML configs or Python code.
Example
Using in a YAML config:
module:
  _target_: stable_pretraining.Module
  forward: stable_pretraining.forward.simclr_forward
  backbone: ...
  projector: ...
Using in Python code:
from stable_pretraining import Module
from stable_pretraining.forward import simclr_forward
module = Module(forward=simclr_forward, backbone=backbone, projector=projector)
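A custom forward follows the same (self, batch, stage) contract; a minimal sketch, where the backbone/classifier attributes and the 'image'/'label' batch keys are illustrative placeholders:
import torch

def my_forward(self, batch, stage):
    # modules passed to Module(...) are available as attributes on self
    out = {"embedding": self.backbone(batch["image"])}
    if "label" in batch:
        logits = self.classifier(out["embedding"])
        out["loss"] = torch.nn.functional.cross_entropy(logits, batch["label"])
        out["label"] = batch["label"]
    return out

module = Module(forward=my_forward, backbone=backbone, classifier=classifier)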
- stable_pretraining.forward.barlow_twins_forward(self, batch, stage)[source]#
Forward function for Barlow Twins.
Barlow Twins learns representations by making the cross-correlation matrix between embeddings of augmented views as close to the identity matrix as possible, reducing redundancy while maintaining invariance.
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: Feature extraction network
- projector: Projection head (typically with BN and high dimension)
- barlow_loss: Barlow Twins loss function
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from backbone
- 'loss': Barlow Twins loss (during training only)
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the Barlow Twins paper [Zbontar et al., 2021].
- stable_pretraining.forward.byol_forward(self, batch, stage)[source]#
Forward function for BYOL (Bootstrap Your Own Latent).
BYOL learns representations without negative pairs by using a momentum-based target network and predicting target projections from online projections.
- Parameters:
self – Module instance with required attributes:
- backbone: TeacherStudentWrapper for feature extraction
- projector: TeacherStudentWrapper for projection head
- predictor: Online network predictor
- byol_loss: BYOL loss function (optional, uses MSE if not provided)
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from teacher backbone (EMA target)
- 'loss': BYOL loss between predictions and targets (during training)
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the BYOL paper [Grill et al., 2020].
- stable_pretraining.forward.dino_forward(self, batch, stage)[source]#
Forward function for DINO (self-DIstillation with NO labels).
DINO learns representations through self-distillation where a student network is trained to match the output of a teacher network (EMA of student) on different augmented views. Global views are processed by both networks while local views are only processed by the student.
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: TeacherStudentWrapper for feature extraction
- projector: TeacherStudentWrapper for projection head
- dino_loss: DINOv1Loss instance (required, pass spt.losses.DINOv1Loss())
- warmup_temperature_teacher (float): Starting teacher temperature
- temperature_teacher (float): Final teacher temperature
- warmup_epochs_temperature_teacher (int): Epochs to warm up temperature
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view). For multi-crop: First 2 views should be global crops, rest are local crops
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from teacher backbone
- 'loss': DINO distillation loss (during training only)
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the DINO paper [Caron et al., 2021]. Requires TeacherStudentWrapper for both backbone and projector, and assumes first 2 views in batch are global views.
- stable_pretraining.forward.dinov2_forward(self, batch, stage)[source]#
Forward function for DINOv2 with iBOT.
DINOv2 combines two self-supervised losses:
- DINO: CLS token distillation between global views
- iBOT: Masked patch prediction
Both losses use Sinkhorn-Knopp normalization for optimal transport.
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: TeacherStudentWrapper for feature extraction (ViT)
- projector: TeacherStudentWrapper for CLS token projection head
- patch_projector: TeacherStudentWrapper for patch projection head
- dinov2_loss: DINOv2Loss instance combining DINO + iBOT
- warmup_temperature_teacher (float): Starting teacher temperature
- temperature_teacher (float): Final teacher temperature
- warmup_epochs_temperature_teacher (int): Epochs to warm up temperature
- mask_ratio (float): Ratio of patches to mask for iBOT (default: 0.3)
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view). For multi-crop: First 2 views should be global crops, rest are local crops
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from teacher backbone
- 'loss': Combined DINOv2 loss (DINO + iBOT)
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the DINOv2 paper. Requires TeacherStudentWrapper for backbone, projector, and patch_projector. Assumes first 2 views in batch are global views.
- stable_pretraining.forward.nnclr_forward(self, batch, stage)[source]#
Forward function for NNCLR (Nearest-Neighbor Contrastive Learning).
NNCLR learns representations by using the nearest neighbor of an augmented view from a support set of past embeddings as a positive pair. This encourages the model to learn representations that are similar for semantically similar instances, not just for different augmentations of the same instance.
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: Feature extraction network
- projector: Projection head for embedding transformation
- predictor: Prediction head used for the online view
- nnclr_loss: NTXent contrastive loss function
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from backbone
- 'loss': NTXent contrastive loss (during training only)
- 'nnclr_support_set': Projections to be added to the support set queue
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the NNCLR paper [Dwibedi et al., 2021].
- stable_pretraining.forward.simclr_forward(self, batch, stage)[source]#
Forward function for SimCLR (Simple Contrastive Learning of Representations).
SimCLR learns representations by maximizing agreement between differently augmented views of the same image via a contrastive loss in the latent space.
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: Feature extraction network
- projector: Projection head mapping features to latent space
- simclr_loss: NT-Xent contrastive loss function
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from backbone
- 'loss': NT-Xent contrastive loss (during training only)
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the SimCLR paper [Chen et al., 2020].
- stable_pretraining.forward.supervised_forward(self, batch, stage)[source]#
Forward function for standard supervised training.
This function implements traditional supervised learning with labels, useful for baseline comparisons and fine-tuning pre-trained models.
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: Feature extraction network
- classifier: Classification head (e.g., Linear layer)
- supervised_loss: Loss function for supervised learning
batch – Input batch dictionary containing:
- 'image': Tensor of images [N, C, H, W]
- 'label': Ground truth labels [N] (optional, for loss computation)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from backbone
- 'logits': Classification predictions
- 'loss': Supervised loss (if labels provided)
- Return type:
dict
Note
Unlike SSL methods, this function uses actual labels for training and is primarily used for evaluation or supervised baselines.
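For reference, a hedged wiring sketch for supervised_forward; the torchvision backbone and dimensions are illustrative choices, not mandated by the library:
import torch
import torchvision
import stable_pretraining as spt
from stable_pretraining.forward import supervised_forward

backbone = torchvision.models.resnet18(weights=None)
backbone.fc = torch.nn.Identity()  # expose the 512-d penultimate features

module = spt.Module(
    forward=supervised_forward,
    backbone=backbone,
    classifier=torch.nn.Linear(512, 10),
    supervised_loss=torch.nn.CrossEntropyLoss(),
)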
- stable_pretraining.forward.swav_forward(self, batch, stage)[source]#
Forward function for SwAV (Swapping Assignments between Views).
SwAV learns representations by predicting the cluster assignment (code) of one view from the representation of another view. For small-batch training, this function manages a feature queue to stabilize the training process.
- stable_pretraining.forward.vicreg_forward(self, batch, stage)[source]#
Forward function for VICReg (Variance-Invariance-Covariance Regularization).
VICReg learns representations using three criteria: variance (maintaining information), invariance (to augmentations), and covariance (decorrelating features).
- Parameters:
self – Module instance (automatically bound) with required attributes:
- backbone: Feature extraction network
- projector: Projection head for embedding transformation
- vicreg_loss: VICReg loss with variance, invariance, covariance terms
batch – Either a list of view dicts (from MultiViewTransform) or a single dict (for validation/single-view)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
Dictionary containing:
- 'embedding': Feature representations from backbone
- 'loss': Combined VICReg loss (during training only)
- 'label': Labels if present (for probes/callbacks)
- Return type:
dict
Note
Introduced in the VICReg paper [Bardes et al., 2022].
stable_pretraining.manager module#
- class stable_pretraining.manager.Manager(*args, **kwargs)[source]#
Bases:
Checkpointable
Manages training with logging, scheduling, and checkpointing support.
- Parameters:
trainer (Union[dict, DictConfig, pl.Trainer]) – PyTorch Lightning trainer configuration or instance.
module (Union[dict, DictConfig, pl.LightningModule]) – Lightning module configuration or instance.
data (Union[dict, DictConfig, pl.LightningDataModule]) – Data module configuration or instance.
seed (int, optional) – Random seed for reproducibility. Defaults to None.
ckpt_path (str, optional) – Path to checkpoint for resuming training. Defaults to “last”.
compile (bool, optional) – Whether to compile the given module. Defaults to False.
- property instantiated_data#
- property instantiated_module#
stable_pretraining.module module#
- class stable_pretraining.module.Module(*args, forward: callable = None, hparams: dict = None, **kwargs)[source]#
Bases:
LightningModule
PyTorch Lightning module using manual optimization with multi-optimizer support.
Core usage
- Provide a custom forward(self, batch, stage) via the forward argument at init.
- During training, forward must return a dict with state["loss"] (a single joint loss). When multiple optimizers are configured, this joint loss is used for all optimizers.
Optimizer configuration (self.optim)
- Single optimizer: {"optimizer": str|dict|partial|Class, "scheduler": <see below>, "interval": "step"|"epoch", "frequency": int}
- Optimizer accepted forms:
  - string name (e.g., "AdamW", "SGD") from torch.optim
  - dict: {"type": "AdamW", "lr": 1e-3, ...}
  - functools.partial: partial(torch.optim.AdamW, lr=1e-3)
  - optimizer class: torch.optim.AdamW
- Multiple optimizers: a dict mapping each optimizer name to its config:
  {name: {
      "modules": "regex",                   # assign params by module-name pattern (children inherit)
      "optimizer": str|dict|partial|Class,  # optimizer factory (same accepted forms as above)
      "scheduler": str|dict|partial|Class,  # flexible scheduler config (see below)
      "interval": "step"|"epoch",           # scheduler interval
      "frequency": int,                     # optimizer step frequency
      "monitor": str,                       # (optional) for ReduceLROnPlateau; alternatively set inside the scheduler dict
  }, ...}
Parameter assignment (multi-optimizer)
- Modules are matched by regex on their qualified name. Children inherit the parent's assignment unless they match a more specific pattern.
- Only direct parameters of each module are collected to avoid duplication.
Schedulers (flexible)
- Accepted forms: string name (e.g., "CosineAnnealingLR", "StepLR"), dict with {"type": "...", ...}, functools.partial, or a scheduler class.
- Smart defaults are applied when params are omitted for common schedulers (CosineAnnealingLR, OneCycleLR, StepLR, ExponentialLR, ReduceLROnPlateau, LinearLR, ConstantLR).
- For ReduceLROnPlateau, a monitor key is added (default: "val_loss"). You may specify monitor either alongside the optimizer config (top level) or inside the scheduler dict itself.
- The resulting Lightning scheduler dict includes interval and frequency (or scheduler_frequency).
Training loop behavior
- Manual optimization (automatic_optimization = False).
- Gradient accumulation: scales loss by 1/N where N = Trainer.accumulate_grad_batches and steps on the boundary.
- Per-optimizer step frequency: each optimizer steps only when its frequency boundary is met (in addition to the accumulation boundary).
- Gradient clipping: uses the Trainer's gradient_clip_val and gradient_clip_algorithm before each step.
- Returns the state dict from forward unchanged for logging/inspection.
- configure_optimizers()[source]#
Configure optimizers and schedulers for manual optimization.
- Returns:
Optimizer configuration with optional learning rate scheduler. For single optimizer: Returns a dict with optimizer and lr_scheduler. For multiple optimizers: Returns a tuple of (optimizers, schedulers).
- Return type:
dict or tuple
Example
Multi-optimizer configuration with module pattern matching and schedulers:
>>> # Simple single optimizer with scheduler
>>> self.optim = {
...     "optimizer": partial(torch.optim.AdamW, lr=1e-3),
...     "scheduler": "CosineAnnealingLR",  # Uses smart defaults
...     "interval": "step",
...     "frequency": 1,
... }
>>> # Multi-optimizer with custom scheduler configs
>>> self.optim = {
...     "encoder_opt": {
...         "modules": "encoder",  # Matches 'encoder' and all children
...         "optimizer": {"type": "AdamW", "lr": 1e-3},
...         "scheduler": {
...             "type": "OneCycleLR",
...             "max_lr": 1e-3,
...             "total_steps": 10000,
...         },
...         "interval": "step",
...         "frequency": 1,
...     },
...     "head_opt": {
...         "modules": ".*head$",  # Matches modules ending with 'head'
...         "optimizer": "SGD",
...         "scheduler": {
...             "type": "ReduceLROnPlateau",
...             "mode": "max",
...             "patience": 5,
...             "factor": 0.5,
...         },
...         "monitor": "val_accuracy",  # Required for ReduceLROnPlateau
...         "interval": "epoch",
...         "frequency": 2,
...     },
... }
With model structure:
- encoder -> encoder_opt (matches "encoder")
- encoder.layer1 -> encoder_opt (inherits from parent)
- encoder.layer1.conv -> encoder_opt (inherits from encoder.layer1)
- classifier_head -> head_opt (matches ".*head$")
- classifier_head.linear -> head_opt (inherits from parent)
- decoder -> None (no match, no parameters collected)
- forward(*args, **kwargs)[source]#
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- named_parameters(with_callbacks=True, prefix: str = '', recurse: bool = True)[source]#
Override to globally exclude callback-related parameters.
Excludes parameters that belong to self.callbacks_modules or self.callbacks_metrics. This prevents accidental optimization of callback/metric internals, even if external code calls self.parameters() or self.named_parameters() directly.
- Parameters:
with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.
prefix (str, optional) – Prefix to prepend to parameter names. Defaults to “”.
recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.
- Yields:
tuple[str, torch.nn.Parameter] – Name and parameter pairs.
- on_save_checkpoint(checkpoint)[source]#
Offload checkpoint tensors to CPU to reduce GPU memory usage during save.
This method intercepts the checkpoint saving process and recursively moves all PyTorch tensors (model weights, optimizer states, scheduler states) from GPU to CPU before writing to disk. This prevents GPU OOM issues when checkpointing large models (e.g., 2B+ parameters with optimizer states).
- Parameters:
checkpoint (dict) – Lightning checkpoint dictionary containing:
- state_dict: Model parameters (moved to CPU)
- optimizer_states: Optimizer state dicts (moved to CPU)
- lr_schedulers: LR scheduler states (moved to CPU)
- Other keys: Custom objects, metadata (left unchanged)
- Behavior:
Processes standard Lightning checkpoint keys (state_dict, optimizer_states, lr_schedulers)
Recursively traverses dicts, lists, and tuples to find tensors
Moves all torch.Tensor objects to CPU
Skips custom objects (returns unchanged)
Logs GPU memory freed and processing time
Non-destructive: Checkpoint loading/resuming works normally
- Side Effects:
Modifies checkpoint dict in-place (tensors moved to CPU)
Temporarily increases CPU memory during offload
Adds ~2-5 seconds to checkpoint save time for 2B models
Frees ~8-12GB GPU memory for 2B model + optimizer states
- Custom Objects:
Custom objects in the checkpoint are NOT modified and will be logged as warnings. These include: custom classes, numpy arrays, primitives, etc. They are safely skipped and preserved in the checkpoint.
- Raises:
Exception – If tensor offload fails for any checkpoint key, logs error but allows checkpoint save to proceed (non-fatal).
Example
For a 2B parameter model with AdamW optimizer:
- Before: ~12GB GPU memory spike on rank 0 during checkpoint save
- After: ~0.2GB GPU memory spike, ~10-12GB freed
- Checkpoint save time: +2-3 seconds
- Resume from checkpoint: Works normally, tensors auto-loaded to GPU
Notes
Only rank 0 saves checkpoints in DDP, so only rank 0 sees memory benefit
Does not affect checkpoint contents or ability to resume training
Safe for standard PyTorch/Lightning use cases
If using FSDP/DeepSpeed, consider strategy-specific checkpointing instead
See also
PyTorch Lightning ModelCheckpoint callback
torch.Tensor.cpu() for device transfer behavior
- parameters(with_callbacks=True, recurse: bool = True)[source]#
Override to route through the filtered named_parameters implementation.
- predict_step(batch, batch_idx)[source]#
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multi-devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.
The BasePredictionWriter should be used while using a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won't be returned.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Predicted output (optional).
Example
class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)
- test_step(batch, batch_idx)[source]#
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
Note
If you don’t need to test you don’t need to implement this method.
Note
When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- training_step(batch, batch_idx)[source]#
Manual optimization training step with support for multiple optimizers.
Expected output from forward during training (stage="fit"):
- state["loss"]: torch.Tensor – Single joint loss for all optimizers
When multiple optimizers are configured, the same loss is used for all of them. Each optimizer updates its assigned parameters based on gradients from this joint loss.
- validation_step(batch, batch_idx)[source]#
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
stable_pretraining.run module#
Universal experiment runner for stable-pretraining using Hydra configs.
This script provides a unified entry point for all experiments (training, evaluation, etc.) via configuration files. It supports both single-file configs and modular Hydra composition.
- Usage:
# Run with a config file
python -m stable_pretraining.run --config-path ../examples --config-name simclr_cifar10
# Run with config and override parameters
python -m stable_pretraining.run --config-path ../examples --config-name simclr_cifar10 module.optimizer.lr=0.01 trainer.max_epochs=200
# Run hyperparameter sweep
python -m stable_pretraining.run --multirun --config-path ../examples --config-name simclr_cifar10 module.optimizer.lr=0.001,0.01,0.1
stable_pretraining.static module#
- class stable_pretraining.static.MetaStatic[source]#
Bases:
type
Metaclass that enables dict-like behavior on the TIMM_PARAMETERS class.
- class stable_pretraining.static.TIMM_EMBEDDINGS[source]#
Bases:
object
Thread-safe, lazy-loaded registry for TIMM (PyTorch Image Models) embedding names, accessed via class-level indexing.
This class provides a mapping from string keys to lists of embedding names, loaded on first access from a JSON file located at ‘assets/static_timm.json’ relative to the source file. The data is cached as a class attribute after the first load, and subsequent accesses are served from memory. The class is intended to be used as a static registry, e.g.:
>>> names = TIMM_EMBEDDINGS["resnet50"]
>>> print(names)  # List of embedding names for 'resnet50'
Notes
The data is loaded only once per process and is shared across all uses of the class.
Thread-safe: concurrent first-time access is protected by a class-level lock.
The class depends on the presence of the ‘assets/static_timm.json’ file two directories above the source file.
The class assumes the __file__ attribute is available and points to the current file.
The class attribute _data is private and shared.
Logging and printing occur on first load for debugging.
File system access and JSON parsing are required at runtime.
- Raises:
RuntimeError – If the assets file is missing.
json.JSONDecodeError – If the file is not valid JSON.
KeyError – If the requested key is not present in the data.
Example
>>> embeddings = TIMM_EMBEDDINGS["vit_base_patch16_224"]
>>> print(embeddings)
- class stable_pretraining.static.TIMM_PARAMETERS[source]#
Bases:
object
Thread-safe singleton class for accessing TIMM (PyTorch Image Models) parameters.
This class provides lazy-loaded, cached access to TIMM model parameters stored in a static JSON file. It implements a dict-like interface with thread-safe initialization and defensive copying to prevent mutation of cached data.
- Usage:
# Access parameters by key
params = TIMM_PARAMETERS['model_name']

# Iterate over keys
for key in TIMM_PARAMETERS.keys():
    print(key)

# Iterate over values
for values in TIMM_PARAMETERS.values():
    print(values)

# Iterate over items
for key, values in TIMM_PARAMETERS.items():
    print(f"{key}: {values}")
Note
All methods return copies of the data to prevent accidental mutation of the internal cache.
Module contents#
- class stable_pretraining.EarlyStopping(mode: str = 'min', milestones: dict[int, float] = None, metric_name: str = None, patience: int = 10)[source]#
Bases:
Module
Early stopping mechanism with support for metric milestones and patience.
This module provides flexible early stopping capabilities that can halt training based on metric performance. It supports both milestone-based stopping (stop if metric doesn’t reach target by specific epochs) and patience-based stopping (stop if metric doesn’t improve for N epochs).
- Parameters:
mode – Optimization direction - ‘min’ for metrics to minimize (e.g., loss), ‘max’ for metrics to maximize (e.g., accuracy).
milestones – Dict mapping epoch numbers to target metric values. Training stops if targets are not met at specified epochs.
metric_name – Name of the metric to monitor if metric is a dict.
patience – Number of epochs with no improvement before stopping.
Example
>>> early_stop = EarlyStopping(mode="max", milestones={10: 0.8, 20: 0.9})
>>> # Stops if accuracy < 0.8 at epoch 10 or < 0.9 at epoch 20
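A patience-based variant can be configured the same way; the metric name below is an illustrative assumption:
>>> early_stop = EarlyStopping(mode="min", metric_name="val_loss", patience=5)
>>> # Stops if the monitored metric has not improved for 5 consecutive epochs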
- class stable_pretraining.ImageRetrieval(pl_module, name: str, input: str, query_col: str, retrieval_col: str | List[str], metrics, features_dim: tuple[int] | list[int] | int, normalizer: str = None)[source]#
Bases:
Callback
Image Retrieval evaluator for self-supervised learning.
- The implementation follows:
- NAME = 'ImageRetrieval'#
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the val epoch ends.
- class stable_pretraining.LiDAR(name: str, target: str, queue_length: int, target_shape: int | Iterable[int], n_classes: int = 100, samples_per_class: int = 10, delta: float = 0.0001, epsilon: float = 1e-08)[source]#
Bases:
Callback
LiDAR (Linear Discriminant Analysis Rank) monitor using queue discovery.
LiDAR measures the effective rank of learned representations using Linear Discriminant Analysis (LDA). It computes the exponential of the entropy of the eigenvalue distribution from the LDA transformation, providing a metric between 1 and min(d, n_classes - 1) where d is the feature dimension, indicating how many dimensions are effectively being used.
This implementation is based on Thilak et al. “LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures” (arXiv:2312.04000).
IMPORTANT: Surrogate Class Formation Requirement#
The LiDAR paper requires that each “surrogate class” consists of q augmented views of the same clean sample. The current implementation chunks the queue sequentially into groups of size samples_per_class. For faithful reproduction of the paper:
Ensure the upstream queue pushes q contiguous augmentations of each clean sample
OR implement ID-based grouping to ensure each group contains views of the same sample
Without proper grouping, the metric may not accurately reflect the paper’s methodology.
The metric helps detect:
- Dimensional collapse in self-supervised learning
- Loss of representational capacity
- Over-regularization effects
- param name:
Unique identifier for this callback instance
- param target:
Key in batch dict containing the feature embeddings to monitor
- param queue_length:
Size of the circular buffer for caching embeddings
- param target_shape:
Shape of the target embeddings (e.g., 768 for 768-dim features)
- param n_classes:
Number of surrogate classes (clean samples) for LDA computation
- param samples_per_class:
Number of augmented samples per class
- param delta:
Regularization constant added to within-class covariance (default: 1e-4)
- param epsilon:
Small constant for numerical stability (default: 1e-8)
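Example
A hedged construction sketch; values are illustrative, and queue_length is chosen as n_classes * samples_per_class so the sequential chunking described above forms complete surrogate classes (an assumption, not a documented requirement):
import stable_pretraining as spt

lidar = spt.LiDAR(
    name="lidar_embed",
    target="embedding",       # key in the batch dict holding the monitored features
    queue_length=100 * 10,
    target_shape=768,         # e.g. 768-dim ViT features
    n_classes=100,
    samples_per_class=10,
)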
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) None[source]#
Compute LiDAR metric on the first validation batch only.
- class stable_pretraining.LoggingCallback[source]#
Bases:
Callback
Displays validation metrics in a color-coded formatted table.
This callback creates a visually appealing table of all validation metrics at the end of each validation epoch. Metrics are color-coded for better readability in terminal outputs.
Features:
- Automatic sorting of metrics by name
- Color coding: blue for metric names, green for values
- Filters out internal metrics (log, progress_bar)
- class stable_pretraining.Manager(*args, **kwargs)[source]#
Bases:
Checkpointable
Manages training with logging, scheduling, and checkpointing support.
- Parameters:
trainer (Union[dict, DictConfig, pl.Trainer]) – PyTorch Lightning trainer configuration or instance.
module (Union[dict, DictConfig, pl.LightningModule]) – Lightning module configuration or instance.
data (Union[dict, DictConfig, pl.LightningDataModule]) – Data module configuration or instance.
seed (int, optional) – Random seed for reproducibility. Defaults to None.
ckpt_path (str, optional) – Path to checkpoint for resuming training. Defaults to “last”.
compile (bool, optional) – Whether to compile the given module. Defaults to False.
- property instantiated_data#
- property instantiated_module#
- class stable_pretraining.Module(*args, forward: callable = None, hparams: dict = None, **kwargs)[source]#
Bases:
LightningModule
PyTorch Lightning module using manual optimization with multi-optimizer support.
Core usage
- Provide a custom forward(self, batch, stage) via the forward argument at init.
- During training, forward must return a dict with state["loss"] (a single joint loss). When multiple optimizers are configured, this joint loss is used for all optimizers.
Optimizer configuration (self.optim)
- Single optimizer: {"optimizer": str|dict|partial|Class, "scheduler": <see below>, "interval": "step"|"epoch", "frequency": int}
- Optimizer accepted forms:
  - string name (e.g., "AdamW", "SGD") from torch.optim
  - dict: {"type": "AdamW", "lr": 1e-3, ...}
  - functools.partial: partial(torch.optim.AdamW, lr=1e-3)
  - optimizer class: torch.optim.AdamW
- Multiple optimizers: a dict mapping each optimizer name to its config:
  {name: {
      "modules": "regex",                   # assign params by module-name pattern (children inherit)
      "optimizer": str|dict|partial|Class,  # optimizer factory (same accepted forms as above)
      "scheduler": str|dict|partial|Class,  # flexible scheduler config (see below)
      "interval": "step"|"epoch",           # scheduler interval
      "frequency": int,                     # optimizer step frequency
      "monitor": str,                       # (optional) for ReduceLROnPlateau; alternatively set inside the scheduler dict
  }, ...}
Parameter assignment (multi-optimizer)
- Modules are matched by regex on their qualified name. Children inherit the parent's assignment unless they match a more specific pattern.
- Only direct parameters of each module are collected to avoid duplication.
Schedulers (flexible)
- Accepted forms: string name (e.g., "CosineAnnealingLR", "StepLR"), dict with {"type": "...", ...}, functools.partial, or a scheduler class.
- Smart defaults are applied when params are omitted for common schedulers (CosineAnnealingLR, OneCycleLR, StepLR, ExponentialLR, ReduceLROnPlateau, LinearLR, ConstantLR).
- For ReduceLROnPlateau, a monitor key is added (default: "val_loss"). You may specify monitor either alongside the optimizer config (top level) or inside the scheduler dict itself.
- The resulting Lightning scheduler dict includes interval and frequency (or scheduler_frequency).
Training loop behavior
- Manual optimization (automatic_optimization = False).
- Gradient accumulation: scales loss by 1/N where N = Trainer.accumulate_grad_batches and steps on the boundary.
- Per-optimizer step frequency: each optimizer steps only when its frequency boundary is met (in addition to the accumulation boundary).
- Gradient clipping: uses the Trainer's gradient_clip_val and gradient_clip_algorithm before each step.
- Returns the state dict from forward unchanged for logging/inspection.
- configure_optimizers()[source]#
Configure optimizers and schedulers for manual optimization.
- Returns:
Optimizer configuration with optional learning rate scheduler. For single optimizer: Returns a dict with optimizer and lr_scheduler. For multiple optimizers: Returns a tuple of (optimizers, schedulers).
- Return type:
dict or tuple
Example
Multi-optimizer configuration with module pattern matching and schedulers:
>>> # Simple single optimizer with scheduler
>>> self.optim = {
...     "optimizer": partial(torch.optim.AdamW, lr=1e-3),
...     "scheduler": "CosineAnnealingLR",  # Uses smart defaults
...     "interval": "step",
...     "frequency": 1,
... }
>>> # Multi-optimizer with custom scheduler configs
>>> self.optim = {
...     "encoder_opt": {
...         "modules": "encoder",  # Matches 'encoder' and all children
...         "optimizer": {"type": "AdamW", "lr": 1e-3},
...         "scheduler": {
...             "type": "OneCycleLR",
...             "max_lr": 1e-3,
...             "total_steps": 10000,
...         },
...         "interval": "step",
...         "frequency": 1,
...     },
...     "head_opt": {
...         "modules": ".*head$",  # Matches modules ending with 'head'
...         "optimizer": "SGD",
...         "scheduler": {
...             "type": "ReduceLROnPlateau",
...             "mode": "max",
...             "patience": 5,
...             "factor": 0.5,
...         },
...         "monitor": "val_accuracy",  # Required for ReduceLROnPlateau
...         "interval": "epoch",
...         "frequency": 2,
...     },
... }
With model structure:
- encoder -> encoder_opt (matches "encoder")
- encoder.layer1 -> encoder_opt (inherits from parent)
- encoder.layer1.conv -> encoder_opt (inherits from encoder.layer1)
- classifier_head -> head_opt (matches ".*head$")
- classifier_head.linear -> head_opt (inherits from parent)
- decoder -> None (no match, no parameters collected)
- forward(*args, **kwargs)[source]#
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- named_parameters(with_callbacks=True, prefix: str = '', recurse: bool = True)[source]#
Override to globally exclude callback-related parameters.
Excludes parameters that belong to self.callbacks_modules or self.callbacks_metrics. This prevents accidental optimization of callback/metric internals, even if external code calls self.parameters() or self.named_parameters() directly.
- Parameters:
with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.
prefix (str, optional) – Prefix to prepend to parameter names. Defaults to “”.
recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.
- Yields:
tuple[str, torch.nn.Parameter] – Name and parameter pairs.
- on_save_checkpoint(checkpoint)[source]#
Offload checkpoint tensors to CPU to reduce GPU memory usage during save.
This method intercepts the checkpoint saving process and recursively moves all PyTorch tensors (model weights, optimizer states, scheduler states) from GPU to CPU before writing to disk. This prevents GPU OOM issues when checkpointing large models (e.g., 2B+ parameters with optimizer states).
- Parameters:
checkpoint (dict) – Lightning checkpoint dictionary containing:
- state_dict: Model parameters (moved to CPU)
- optimizer_states: Optimizer state dicts (moved to CPU)
- lr_schedulers: LR scheduler states (moved to CPU)
- Other keys: Custom objects, metadata (left unchanged)
- Behavior:
Processes standard Lightning checkpoint keys (state_dict, optimizer_states, lr_schedulers)
Recursively traverses dicts, lists, and tuples to find tensors
Moves all torch.Tensor objects to CPU
Skips custom objects (returns unchanged)
Logs GPU memory freed and processing time
Non-destructive: Checkpoint loading/resuming works normally
- Side Effects:
Modifies checkpoint dict in-place (tensors moved to CPU)
Temporarily increases CPU memory during offload
Adds ~2-5 seconds to checkpoint save time for 2B models
Frees ~8-12GB GPU memory for 2B model + optimizer states
- Custom Objects:
Custom objects in the checkpoint are NOT modified and will be logged as warnings. These include: custom classes, numpy arrays, primitives, etc. They are safely skipped and preserved in the checkpoint.
- Raises:
Exception – If tensor offload fails for any checkpoint key, logs error but allows checkpoint save to proceed (non-fatal).
Example
For a 2B parameter model with AdamW optimizer:
- Before: ~12GB GPU memory spike on rank 0 during checkpoint save
- After: ~0.2GB GPU memory spike, ~10-12GB freed
- Checkpoint save time: +2-3 seconds
- Resume from checkpoint: Works normally, tensors auto-loaded to GPU
Notes
Only rank 0 saves checkpoints in DDP, so only rank 0 sees memory benefit
Does not affect checkpoint contents or ability to resume training
Safe for standard PyTorch/Lightning use cases
If using FSDP/DeepSpeed, consider strategy-specific checkpointing instead
See also
PyTorch Lightning ModelCheckpoint callback
torch.Tensor.cpu() for device transfer behavior
- parameters(with_callbacks=True, recurse: bool = True)[source]#
Override to route through the filtered named_parameters implementation.
- predict_step(batch, batch_idx)[source]#
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multi-devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.
The BasePredictionWriter should be used while using a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won't be returned.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Predicted output (optional).
Example
class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)
- test_step(batch, batch_idx)[source]#
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
Note
If you don’t need to test you don’t need to implement this method.
Note
When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- training_step(batch, batch_idx)[source]#
Manual optimization training step with support for multiple optimizers.
Expected output from forward during training (stage="fit"):
- state["loss"]: torch.Tensor – Single joint loss for all optimizers
When multiple optimizers are configured, the same loss is used for all of them. Each optimizer updates its assigned parameters based on gradients from this joint loss.
- validation_step(batch, batch_idx)[source]#
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- class stable_pretraining.ModuleSummary[source]#
Bases:
CallbackLogs detailed module parameter statistics in a formatted table.
This callback provides a comprehensive overview of all modules in the model, showing the number of trainable, non-trainable, uninitialized parameters, and buffers for each module. This helps understand model architecture and parameter distribution.
The summary is displayed during the setup phase and includes:
- Module name and hierarchy
- Trainable parameter count
- Non-trainable (frozen) parameter count
- Uninitialized parameter count (for lazy modules)
- Buffer count (non-parameter persistent state)
- class stable_pretraining.OnlineKNN(name: str, input: str, target: str, queue_length: int, metrics: Dict, input_dim: Tuple[int, ...] | List[int] | int | None = None, target_dim: int | None = None, k: int = 5, temperature: float = 0.07, chunk_size: int = -1, distance_metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean')[source]#
Bases:
Callback
Weighted K-Nearest Neighbors online evaluator using queue discovery.
This callback implements a weighted KNN classifier that evaluates the quality of learned representations during training. It automatically discovers or creates OnlineQueue callbacks to maintain circular buffers of features and labels, then uses this cached data to compute KNN predictions during validation.
The KNN evaluation is performed by:
1. Finding k nearest neighbors in the feature space
2. Weighting neighbors by inverse distance with temperature scaling
3. Using weighted voting to produce class predictions
4. Computing specified metrics on the predictions
- Parameters:
name – Unique identifier for this callback instance. Used for logging and storing metrics.
input – Key in batch dict containing input features to evaluate.
target – Key in batch dict containing ground truth target labels.
queue_length – Size of the circular buffer for caching features and labels. Larger values provide more representative samples but use more memory.
metrics – Dictionary of metrics to compute during validation. Keys are metric names, values are metric instances (e.g., torchmetrics.Accuracy).
input_dim – Expected dimensionality of input features. Can be int, tuple/list (will be flattened to product), or None to accept any dimension.
target_dim – Expected dimensionality of targets. None accepts any dimension.
k – Number of nearest neighbors to consider for voting. Default is 5.
temperature – Temperature parameter for distance weighting. Lower values give more weight to closer neighbors. Default is 0.07.
chunk_size – Batch size for memory-efficient distance computation. Set to -1 to compute all distances at once. Default is -1.
distance_metric – Distance metric for finding nearest neighbors. Options are ‘euclidean’, ‘squared_euclidean’, ‘cosine’, ‘manhattan’. Default is ‘euclidean’.
- Raises:
ValueError – If k <= 0, temperature <= 0, or chunk_size is invalid.
Note
The callback automatically handles distributed training by gathering data
Mixed precision is supported through automatic dtype conversion
Predictions are stored in batch dict with key ‘{name}_preds’
Metrics are logged with prefix ‘eval/{name}_’
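Example
A hedged construction sketch; the batch keys ('embedding', 'label'), dimensions, and metric choice are illustrative:
import torchmetrics
import stable_pretraining as spt

knn = spt.OnlineKNN(
    name="knn_probe",
    input="embedding",
    target="label",
    queue_length=20000,
    metrics={"top1": torchmetrics.classification.MulticlassAccuracy(num_classes=10)},
    input_dim=512,
    k=20,
    temperature=0.07,
    distance_metric="cosine",
)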
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int, dataloader_idx: int = 0) None[source]#
Compute KNN predictions during validation.
- class stable_pretraining.OnlineProbe(module: LightningModule, name: str, input: str, target: str, probe: Module, loss_fn: callable = None, optimizer: str | dict | partial | Optimizer | None = None, scheduler: str | dict | partial | LRScheduler | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: float = None, gradient_clip_algorithm: str = 'norm', metrics: dict | tuple | list | Metric | None = None)[source]#
Bases: TrainableCallback
Online probe for evaluating learned representations during self-supervised training.
This callback implements the standard linear evaluation protocol by training a probe (typically a linear classifier) on top of frozen features from the main model. The probe is trained simultaneously with the main model but maintains its own optimizer, scheduler, and training loop. This allows monitoring representation quality throughout training without modifying the base model.
Key features:
- Automatic gradient detachment to prevent probe gradients affecting the main model
- Independent optimizer and scheduler management
- Support for gradient accumulation
- Mixed precision training compatibility through automatic dtype conversion
- Metric tracking and logging
- Parameters:
module – The spt.LightningModule to probe.
name – Unique identifier for this probe instance. Used for logging and storing metrics/modules.
input – Key in batch dict or outputs dict containing input features to probe.
target – Key in batch dict containing ground truth target labels.
probe – The probe module to train. Can be a nn.Module instance, callable that returns a module, or Hydra config to instantiate.
loss_fn – Loss function for probe training (e.g., nn.CrossEntropyLoss()).
optimizer – Optimizer configuration for the probe. Can be:
- str: optimizer name (e.g., “AdamW”, “SGD”, “LARS”)
- dict: {“type”: “AdamW”, “lr”: 1e-3, …}
- partial: pre-configured optimizer factory
- optimizer instance or callable
- None: uses LARS(lr=0.1, clip_lr=True, eta=0.02, exclude_bias_n_norm=True, weight_decay=0), which is the standard for SSL linear probes (default)
scheduler – Learning rate scheduler configuration. Can be:
- str: scheduler name (e.g., “CosineAnnealingLR”, “StepLR”)
- dict: {“type”: “CosineAnnealingLR”, “T_max”: 1000, …}
- partial: pre-configured scheduler factory
- scheduler instance or callable
- None: uses ConstantLR(factor=1.0), maintaining a constant learning rate (default)
accumulate_grad_batches – Number of batches to accumulate gradients before optimizer step. Default is 1 (no accumulation).
metrics – Metrics to track during training/validation. Can be dict, list, tuple, or single metric instance.
Note
- The probe module is stored in pl_module.callbacks_modules[name]
- Metrics are stored in pl_module.callbacks_metrics[name]
- Predictions are stored in the batch dict with key ‘{name}_preds’
- Loss is logged as ‘train/{name}_loss’
- Metrics are logged with prefix ‘train/{name}_’ and ‘eval/{name}_’
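A minimal configuration sketch is shown below; the key names, feature dimension, number of classes, and metric choice are placeholders that depend on your own forward function and dataset:

import torch
import torchmetrics
import stable_pretraining as spt

def make_linear_probe(module):
    # module: the spt LightningModule being pretrained (passed in by the caller)
    return spt.OnlineProbe(
        module=module,
        name="linear_probe",
        input="embedding",                  # batch/outputs key holding the frozen features
        target="label",                     # batch key holding the ground-truth labels
        probe=torch.nn.Linear(2048, 1000),  # linear evaluation head (placeholder dims)
        loss_fn=torch.nn.CrossEntropyLoss(),
        metrics={"top1": torchmetrics.classification.MulticlassAccuracy(num_classes=1000)},
    )

The returned callback is then passed to the Lightning Trainer’s callbacks list alongside any other callbacks.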
- class stable_pretraining.OnlineWriter(names: str, path: str | Path, during: str | list[str], every_k_epochs: int = -1, save_last_epoch: bool = False, save_sanity_check: bool = False, all_gather: bool = True)[source]#
Bases: Callback
Writes specified batch data to disk during training and validation.
This callback enables selective saving of batch data (e.g., features, predictions, embeddings) to disk at specified intervals during training. It’s useful for debugging, visualization, and analysis of model behavior during training.
Features:
- Flexible saving schedule (every k epochs, last epoch, sanity check)
- Support for distributed training with optional all_gather
- Automatic directory creation
- Configurable for different training phases (train, val, test)
- Parameters:
names – Name(s) of the batch keys to save. Can be string or list of strings.
path – Directory path where files will be saved.
during – Training phase(s) when to save (‘train’, ‘val’, ‘test’, or list).
every_k_epochs – Save every k epochs. -1 means every epoch.
save_last_epoch – Whether to save on the last training epoch.
save_sanity_check – Whether to save during sanity check phase.
all_gather – Whether to gather data across all distributed processes.
Files are saved with naming pattern: {phase}_{name}_epoch{epoch}_batch{batch}.pt
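A minimal configuration sketch, assuming the batch dict exposes an ‘embedding’ key (the key name and output path are placeholders):

import stable_pretraining as spt

writer_cb = spt.OnlineWriter(
    names=["embedding"],   # batch key(s) to dump to disk
    path="./dumps",        # output directory (created automatically)
    during=["val"],        # only write during validation
    every_k_epochs=10,     # write every 10 epochs (-1 would mean every epoch)
    all_gather=True,       # gather data across ranks before writing
)
# files are then written as e.g. val_embedding_epoch{epoch}_batch{batch}.pt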
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called when the predict batch ends.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called when the test batch ends.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called when the train batch ends.
Note
The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.
- class stable_pretraining.RankMe(name: str, target: str, queue_length: int, target_shape: int | Iterable[int])[source]#
Bases: Callback
RankMe (effective rank) monitor using queue discovery.
RankMe measures the effective rank of feature representations by computing the exponential of the entropy of normalized singular values. This metric helps detect dimensional collapse in self-supervised learning.
- Parameters:
name – Unique name for this callback instance
target – Key in batch dict containing the feature embeddings to monitor
queue_length – Required length of the queue that caches the target embeddings for the RankMe computation
target_shape – Shape of the target embeddings (e.g., 768 for 768-dim features)
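The underlying quantity can be sketched as follows. This illustrates the effective-rank computation on a queue of embeddings and is not the callback’s exact implementation (the epsilon is an assumed numerical-stability term):

import torch

def effective_rank(embeddings: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # embeddings: (N, D) queue of feature vectors
    s = torch.linalg.svdvals(embeddings)   # singular values of the feature matrix
    p = s / s.sum() + eps                  # normalized singular-value distribution
    entropy = -(p * torch.log(p)).sum()    # Shannon entropy of the spectrum
    return torch.exp(entropy)              # RankMe: exp(entropy), the effective rank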
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) None[source]#
Compute RankMe metric on the first validation batch only.
- class stable_pretraining.SklearnCheckpoint[source]#
Bases: Callback
Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.
This callback automatically detects sklearn models (Regressors and Classifiers) attached to the Lightning module and handles their serialization/deserialization during checkpoint save/load operations. This is necessary because sklearn models are not natively supported by PyTorch’s checkpoint system.
The callback will:
1. Automatically discover sklearn models attached to the Lightning module
2. Save them to the checkpoint dictionary during checkpoint saving
3. Restore them from the checkpoint during checkpoint loading
4. Log information about discovered sklearn modules during setup
Note
- Only attributes that are sklearn RegressorMixin or ClassifierMixin instances are saved
- Private attributes (starting with ‘_’) are ignored
- The callback will raise an error if a sklearn model name conflicts with existing checkpoint keys
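The discovery step can be sketched roughly as follows (a hypothetical helper, not the callback’s actual code):

from sklearn.base import ClassifierMixin, RegressorMixin

def discover_sklearn_models(pl_module):
    # collect public attributes that are sklearn regressors or classifiers
    found = {}
    for name, value in vars(pl_module).items():
        if name.startswith("_"):
            continue  # private attributes are ignored
        if isinstance(value, (RegressorMixin, ClassifierMixin)):
            found[name] = value
    return found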
- on_load_checkpoint(trainer, pl_module, checkpoint)[source]#
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
- on_save_checkpoint(trainer, pl_module, checkpoint)[source]#
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
- class stable_pretraining.TeacherStudentCallback(update_frequency: int = 1, update_after_backward: bool = False)[source]#
Bases: Callback
Automatically updates TeacherStudentWrapper instances during training.
This callback handles the EMA (Exponential Moving Average) updates for any TeacherStudentWrapper instances found in the model. It updates both the teacher parameters and the EMA coefficient schedule.
The callback automatically detects all TeacherStudentWrapper instances in the model hierarchy and updates them at the appropriate times during training.
- Parameters:
update_frequency – Number of training steps between teacher updates. Default is 1 (update on every step).
update_after_backward – If True, update the teacher right after the backward pass instead of at the end of the training batch. Default is False.
Example
>>> backbone = ResNet18()
>>> wrapped_backbone = TeacherStudentWrapper(backbone)
>>> module = ssl.Module(backbone=wrapped_backbone, ...)
>>> trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
- on_after_backward(trainer: Trainer, pl_module: LightningModule) None[source]#
Update teacher models after backward pass if update_after_backward is True.
- class stable_pretraining.TeacherStudentWrapper(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#
Bases: Module
Backbone wrapper that implements teacher-student distillation via EMA.
This is a wrapper for backbones that creates a teacher model as an exponential moving average (EMA) of the student model. It should be passed as the backbone to stable_pretraining.Module and accessed via forward_student() and forward_teacher() methods in your custom forward function.
The teacher model is updated by taking a running average of the student’s parameters and buffers. When ema_coefficient == 0.0, the teacher and student are literally the same object, saving memory but forward passes through the teacher will not produce any gradients.
- Usage example:
backbone = ResNet18()
wrapped_backbone = TeacherStudentWrapper(backbone)
module = ssl.Module(
    backbone=wrapped_backbone,
    projector=projector,
    forward=forward_with_teacher_student,
    …
)
- Parameters:
student (torch.nn.Module) – The student model whose parameters will be tracked.
warm_init (bool, optional) – If True, performs an initialization step to match the student’s parameters immediately. Default is True.
base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.
final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.
- forward(*args, **kwargs)[source]#
Forward pass through either the student or teacher network.
You can choose which model to run in the default forward. Commonly the teacher is evaluated, so we default to that.
- forward_student(*args, **kwargs)[source]#
Forward pass through the student network. Gradients will flow normally.
- forward_teacher(*args, **kwargs)[source]#
Forward pass through the teacher network.
By default, the teacher network does not require grad. If ema_coefficient == 0, then teacher==student, so we wrap in torch.no_grad() to ensure no gradients flow.
- update_ema_coefficient(epoch: int, total_epochs: int)[source]#
Update the EMA coefficient following a cosine schedule.
- The EMA coefficient is updated following a cosine schedule:
ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))
- update_teacher()[source]#
Perform one EMA update step on the teacher’s parameters.
- The update rule is:
teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param
This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, but the student remains fully trainable.
Everything is updated, including buffers (e.g. batch norm running averages).
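A minimal sketch of the update rule and the cosine schedule documented above (illustrative only; the buffer handling shown here is an assumption, with integer buffers simply copied):

import math
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, ema: float) -> None:
    # teacher_param = ema * teacher_param + (1 - ema) * student_param
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(ema).add_(s, alpha=1 - ema)
    # buffers (e.g. batch norm running averages) are updated as well
    for t, s in zip(teacher.buffers(), student.buffers()):
        if t.dtype.is_floating_point:
            t.mul_(ema).add_(s, alpha=1 - ema)
        else:
            t.copy_(s)  # e.g. num_batches_tracked

def cosine_ema(epoch: int, total_epochs: int, base: float = 0.996, final: float = 1.0) -> float:
    # ema = final - 0.5 * (final - base) * (1 + cos(epoch / total_epochs * pi))
    return final - 0.5 * (final - base) * (1 + math.cos(epoch / total_epochs * math.pi))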
- class stable_pretraining.TrainerInfo[source]#
Bases: Callback
Links the trainer to the DataModule for enhanced functionality.
This callback establishes a bidirectional connection between the trainer and DataModule, enabling the DataModule to access trainer information such as device placement, distributed training state, and other runtime configurations.
This is particularly useful for DataModules that need to adapt their behavior based on trainer configuration (e.g., device-aware data loading, distributed sampling adjustments).
Note
Only works with DataModule instances that have a set_pl_trainer method. A warning is logged if using a custom DataModule without this method.
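A custom DataModule can opt in by exposing the hook, roughly as in this sketch (the class name and stored attribute are illustrative):

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    def set_pl_trainer(self, trainer: pl.Trainer) -> None:
        # called by TrainerInfo during setup; keep a reference for later use,
        # e.g. device-aware data loading or distributed sampling adjustments
        self.pl_trainer = trainer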