stable_pretraining.callbacks package#
Submodules#
stable_pretraining.callbacks.checkpoint_sklearn module#
- class stable_pretraining.callbacks.checkpoint_sklearn.SklearnCheckpoint[source]#
Bases: Callback
Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.
This callback automatically detects sklearn models (Regressors and Classifiers) attached to the Lightning module and handles their serialization/deserialization during checkpoint save/load operations. This is necessary because sklearn models are not natively supported by PyTorch’s checkpoint system.
The callback will:
1. Automatically discover sklearn models attached to the Lightning module
2. Save them to the checkpoint dictionary during checkpoint saving
3. Restore them from the checkpoint during checkpoint loading
4. Log information about discovered sklearn modules during setup
Note
Only attributes that are sklearn RegressorMixin or ClassifierMixin instances are saved
Private attributes (starting with ‘_’) are ignored
The callback will raise an error if a sklearn model name conflicts with existing checkpoint keys
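Example
A minimal usage sketch, assuming a sklearn classifier is attached as a public attribute of your LightningModule (the module class, attribute name, and dimensions below are illustrative, not part of the library):
```python
import lightning.pytorch as pl
import torch
from sklearn.linear_model import LogisticRegression

from stable_pretraining.callbacks.checkpoint_sklearn import SklearnCheckpoint


class MyModule(pl.LightningModule):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 2)
        # Public sklearn attribute -> discovered and saved by SklearnCheckpoint
        self.linear_eval = LogisticRegression(max_iter=1000)


trainer = pl.Trainer(callbacks=[SklearnCheckpoint()])
```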
- on_load_checkpoint(trainer, pl_module, checkpoint)[source]#
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
- on_save_checkpoint(trainer, pl_module, checkpoint)[source]#
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
- class stable_pretraining.callbacks.checkpoint_sklearn.StrictCheckpointCallback(strict: bool = True)[source]#
Bases: Callback
A PyTorch Lightning callback that controls strict checkpoint loading behavior.
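Example
A minimal sketch; the interpretation of strict=False as relaxing key matching when restoring weights is inferred from the parameter name, not stated by the library:
```python
from lightning.pytorch import Trainer

from stable_pretraining.callbacks.checkpoint_sklearn import StrictCheckpointCallback

# strict=False presumably allows loading checkpoints with missing/unexpected keys
trainer = Trainer(callbacks=[StrictCheckpointCallback(strict=False)])
```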
- class stable_pretraining.callbacks.checkpoint_sklearn.WandbCheckpoint[source]#
Bases: Callback
Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.
This callback automatically detects sklearn models (Regressors and Classifiers) attached to the Lightning module and handles their serialization/deserialization during checkpoint save/load operations. This is necessary because sklearn models are not natively supported by PyTorch’s checkpoint system.
The callback will:
1. Automatically discover sklearn models attached to the Lightning module
2. Save them to the checkpoint dictionary during checkpoint saving
3. Restore them from the checkpoint during checkpoint loading
4. Log information about discovered sklearn modules during setup
Note
Only attributes that are sklearn RegressorMixin or ClassifierMixin instances are saved
Private attributes (starting with ‘_’) are ignored
The callback will raise an error if a sklearn model name conflicts with existing checkpoint keys
- on_load_checkpoint(trainer, pl_module, checkpoint)[source]#
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
- on_save_checkpoint(trainer, pl_module, checkpoint)[source]#
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
stable_pretraining.callbacks.cleanup module#
- class stable_pretraining.callbacks.cleanup.CleanUpCallback(slurm_patterns: Sequence[str] | None = None, search_paths: Sequence[str] | None = None, delete_hydra: bool = True, dry_run: bool = False)[source]#
Bases: Callback
PyTorch Lightning callback to monitor and clean up SLURM and Hydra files during and after training.
At each epoch, prints the names and sizes of SLURM and Hydra files. At the end of successful training, deletes those files and prints a summary. Safe for DDP (only rank 0 prints/deletes).
- Parameters:
slurm_patterns (Optional[Sequence[str]]) – Glob patterns for SLURM files to monitor/delete. Defaults to ["slurm-*.out", "slurm-*.err"].
search_paths (Optional[Sequence[str]]) – Directories to search for SLURM files. Defaults to [os.getcwd(), $SLURM_SUBMIT_DIR if set].
delete_hydra (bool) – Whether to delete Hydra artifacts (.hydra directory and hydra.log). Defaults to True.
dry_run (bool) – If True, only print what would be deleted, do not actually delete. Defaults to False.
Example
```python
from pytorch_lightning import Trainer

cleanup_cb = CleanUpCallback()
trainer = Trainer(callbacks=[cleanup_cb])
```
- Example (custom patterns, dry run):
```python
cleanup_cb = CleanUpCallback(
    slurm_patterns=["slurm-*.out", "myjob-*.log"],
    search_paths=["/scratch/logs", "/tmp"],
    delete_hydra=False,
    dry_run=True,
)
```
- on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#
Mark that an exception occurred, so files will not be deleted at the end.
- Parameters:
trainer (Trainer) – The PyTorch Lightning Trainer.
pl_module (LightningModule) – The LightningModule being trained.
exception (BaseException) – The exception that was raised.
stable_pretraining.callbacks.clip_zero_shot module#
- class stable_pretraining.callbacks.clip_zero_shot.CLIPZeroShot(name: str, image_key: str, class_key: str, class_names: list[str], image_backbone: Module, text_backbone: Module, tokenizer_fn: Callable[[str | list[str]], Tensor], metrics: dict | tuple | list | Metric | None = None)[source]#
Bases: Callback
Zero-shot classification evaluator for CLIP-style models.
This callback produces zero-shot predictions by computing the similarity between image embeddings and the text embeddings of the class names.
- Parameters:
name – Unique identifier for this callback instance (used as log prefix and registry key).
image_key – Key in batch or outputs containing input images or precomputed image features.
tokens_key – Key in batch containing tokenized text.
class_key – Key in batch containing ground-truth class indices (0..C-1, aligned with class_names order).
class_names – List of class names in index order.
image_backbone – Module/callable to encode images into embeddings.
text_backbone – Module/callable to encode tokenized text into embeddings.
tokenizer_fn – Callable that maps str | list[str] -> tensor of shape (T,).
metrics – Dict of torchmetrics to compute on validation (e.g., {“top1”: MulticlassAccuracy(…)}).
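Example
A construction sketch; the backbones, tokenizer, batch keys, and dimensions below are placeholders and not part of the library:
```python
import torch
from torchmetrics.classification import MulticlassAccuracy

from stable_pretraining.callbacks.clip_zero_shot import CLIPZeroShot

class_names = ["cat", "dog", "bird"]


def dummy_tokenizer(text):
    # Placeholder: maps a string (or list of strings) to token ids of shape (T,)
    return torch.zeros(16, dtype=torch.long)


zeroshot = CLIPZeroShot(
    name="clip_zeroshot",
    image_key="image",                        # key in the batch dict with images
    class_key="label",                        # key with ground-truth class indices
    class_names=class_names,
    image_backbone=torch.nn.Linear(224, 64),  # placeholder encoders
    text_backbone=torch.nn.Embedding(1000, 64),
    tokenizer_fn=dummy_tokenizer,
    metrics={"top1": MulticlassAccuracy(num_classes=len(class_names))},
)
```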
stable_pretraining.callbacks.cpu_offload module#
- class stable_pretraining.callbacks.cpu_offload.CPUOffloadCallback(offload_keys: List[str] | None = None, log_skipped: bool = False)[source]#
Bases: Callback
Offload checkpoint tensors to CPU during save to reduce GPU memory usage.
This callback intercepts checkpoint saving and moves all PyTorch tensors (model weights, optimizer states, scheduler states) from GPU to CPU before writing to disk. Prevents GPU OOM for large models (2B+ parameters).
Compatible strategies:
- DDP (DistributedDataParallel)
- Single-GPU training
Incompatible strategies (auto-disabled):
- FSDP (uses sharded checkpointing)
- DeepSpeed (has a custom checkpoint mechanism)
- Other sharding strategies
- Parameters:
offload_keys – Keys to offload. Defaults to [‘state_dict’, ‘optimizer_states’, ‘lr_schedulers’]
log_skipped – If False, only logs first/last 10 skipped objects (default: False). If True, logs all skipped objects.
Example
```python
from lightning.pytorch import Trainer

# Just add to callbacks!
trainer = Trainer(
    strategy="ddp",  # Compatible
    callbacks=[
        CPUOffloadCallback(),  # Auto-enables
        ModelCheckpoint(...),
    ],
)
```
- Benefits:
2B model + optimizer: ~12GB GPU memory freed on rank 0
No code changes needed in LightningModule
Safe resumption - tensors auto-loaded to correct device
Adds ~2-5s to checkpoint save time
Auto-detects incompatible strategies and disables itself
Notes
Only affects rank 0 in DDP (only rank that saves)
Custom objects in checkpoint are safely skipped
Does not affect checkpoint contents or resumption
Compatible with all PyTorch Lightning versions >= 2.0
- on_exception(trainer, pl_module, exception)[source]#
Called when an exception occurs during training.
stable_pretraining.callbacks.earlystop module#
- class stable_pretraining.callbacks.earlystop.EpochMilestones(milestones: dict[int, float], monitor: list[str] | str = None, contains: str = None, direction: str = 'max', after_validation: bool = True, strict: bool = True)[source]#
Bases: Callback
PyTorch Lightning callback to stop training if a monitored metric does not meet specified thresholds at given epochs.
This callback allows you to define “milestones”—specific epochs at which a metric must surpass (or fall below) a given value. If the metric fails to meet the requirement at the milestone epoch, training is stopped early.
- Parameters:
monitor (str or list[str]) – The name(s) of the metric(s) to monitor (as logged in trainer.callback_metrics).
milestones (dict[int, float]) – A dictionary mapping epoch numbers (int) to required metric values (float). At each specified epoch, the metric is checked against the corresponding value.
direction (str, optional) – One of "max" or "min". With "max", training stops if the metric is less than or equal to the milestone value; with "min", training stops if the metric is greater than or equal to the milestone value. Default is "max".
after_validation (bool, optional) – If True (default), the metric is checked after validation (on_validation_end). If False, the metric is checked after training (on_training_end).
- Raises:
ValueError – If the specified metric is not found in trainer.callback_metrics at the milestone epoch.
Example
>>> milestones = {10: 0.2, 20: 0.5}
>>> callback = EpochMilestones(
...     monitor="eva/accuracy",
...     milestones=milestones,
...     direction="max",
...     after_validation=True,
... )
>>> trainer = pl.Trainer(callbacks=[callback])
stable_pretraining.callbacks.embedding_cache module#
- class stable_pretraining.callbacks.embedding_cache.EmbeddingCache(module_names: list, add_to_forward_output: bool = True)[source]#
Bases: Callback
Cache embeddings from modules, given their names.
- Parameters:
module_names (list of str) – List of module names to hook (e.g., ['layer1', 'encoder.block1']).
add_to_forward_output (bool) – If True, merges cached outputs into the dict returned by forward.
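Example
A minimal wiring sketch; the module names below are illustrative and must match names returned by named_modules() on your model:
```python
from lightning.pytorch import Trainer

from stable_pretraining.callbacks.embedding_cache import EmbeddingCache

# Hook two submodules by name; their outputs are cached under those keys
cache = EmbeddingCache(
    module_names=["backbone.layer1", "backbone.layer4"],
    add_to_forward_output=True,
)
trainer = Trainer(callbacks=[cache])
```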
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]#
Called when the test batch begins.
- on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]#
Called when the train batch begins.
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]#
Called when the validation batch begins.
stable_pretraining.callbacks.env_info module#
- class stable_pretraining.callbacks.env_info.EnvironmentDumpCallback(filename: str = 'environment.json', async_dump: bool = True)[source]#
Bases: Callback
Dumps the complete environment configuration to enable exact reproduction.
DDP-safe: only runs on rank 0. Uses loguru for comprehensive logging of all operations.
- Parameters:
filename – Name of the file to save environment info to.
async_dump – If True, runs the dump in a background thread (non-blocking).
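Example
A minimal sketch using the default settings:
```python
from lightning.pytorch import Trainer

from stable_pretraining.callbacks.env_info import EnvironmentDumpCallback

# Writes environment.json from rank 0, in a background thread by default
trainer = Trainer(callbacks=[EnvironmentDumpCallback(filename="environment.json")])
```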
stable_pretraining.callbacks.factories module#
stable_pretraining.callbacks.image_retrieval module#
- class stable_pretraining.callbacks.image_retrieval.ImageRetrieval(pl_module, name: str, input: str, query_col: str, retrieval_col: str | List[str], metrics, features_dim: tuple[int] | list[int] | int, normalizer: str = None)[source]#
Bases: Callback
Image retrieval evaluator for self-supervised learning.
- The implementation follows:
- NAME = 'ImageRetrieval'#
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the val epoch ends.
stable_pretraining.callbacks.knn module#
- class stable_pretraining.callbacks.knn.OnlineKNN(name: str, input: str, target: str, queue_length: int, metrics: Dict, input_dim: Tuple[int, ...] | List[int] | int | None = None, target_dim: int | None = None, k: int = 5, temperature: float = 0.07, chunk_size: int = -1, distance_metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean')[source]#
Bases: Callback
Weighted K-Nearest Neighbors online evaluator using queue discovery.
This callback implements a weighted KNN classifier that evaluates the quality of learned representations during training. It automatically discovers or creates OnlineQueue callbacks to maintain circular buffers of features and labels, then uses this cached data to compute KNN predictions during validation.
The KNN evaluation is performed by:
1. Finding the k nearest neighbors in the feature space
2. Weighting neighbors by inverse distance with temperature scaling
3. Using weighted voting to produce class predictions
4. Computing the specified metrics on the predictions
- Parameters:
name – Unique identifier for this callback instance. Used for logging and storing metrics.
input – Key in batch dict containing input features to evaluate.
target – Key in batch dict containing ground truth target labels.
queue_length – Size of the circular buffer for caching features and labels. Larger values provide more representative samples but use more memory.
metrics – Dictionary of metrics to compute during validation. Keys are metric names, values are metric instances (e.g., torchmetrics.Accuracy).
input_dim – Expected dimensionality of input features. Can be int, tuple/list (will be flattened to product), or None to accept any dimension.
target_dim – Expected dimensionality of targets. None accepts any dimension.
k – Number of nearest neighbors to consider for voting. Default is 5.
temperature – Temperature parameter for distance weighting. Lower values give more weight to closer neighbors. Default is 0.07.
chunk_size – Batch size for memory-efficient distance computation. Set to -1 to compute all distances at once. Default is -1.
distance_metric – Distance metric for finding nearest neighbors. Options are ‘euclidean’, ‘squared_euclidean’, ‘cosine’, ‘manhattan’. Default is ‘euclidean’.
- Raises:
ValueError – If k <= 0, temperature <= 0, or chunk_size is invalid.
Note
The callback automatically handles distributed training by gathering data
Mixed precision is supported through automatic dtype conversion
Predictions are stored in batch dict with key ‘{name}_preds’
Metrics are logged with prefix ‘eval/{name}_’
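Example
A construction sketch; the batch keys, feature dimension, and class count are placeholders chosen for illustration:
```python
from torchmetrics.classification import MulticlassAccuracy

from stable_pretraining.callbacks.knn import OnlineKNN

knn_eval = OnlineKNN(
    name="knn_probe",
    input="embedding",       # batch key with features
    target="label",          # batch key with ground-truth labels
    queue_length=16384,
    metrics={"top1": MulticlassAccuracy(num_classes=1000)},
    input_dim=768,
    k=20,
    temperature=0.07,
    distance_metric="cosine",
)
```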
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int, dataloader_idx: int = 0) None[source]#
Compute KNN predictions during validation.
stable_pretraining.callbacks.latent_viz module#
Online latent space visualization callback with dimensionality reduction.
This callback learns a 2D projection of high-dimensional features while preserving neighborhood structure using a contrastive loss between high-D and low-D similarities.
- class stable_pretraining.callbacks.latent_viz.LatentViz(name: str, input: str, target: str | None, projection: Module, queue_length: int = 2048, k_neighbors: int = 15, n_negatives: int = 5, optimizer: str | dict | partial | Optimizer | None = None, scheduler: str | dict | partial | LRScheduler | None = None, accumulate_grad_batches: int = 1, update_interval: int = 10, warmup_epochs: int = 0, distance_metric: Literal['euclidean', 'cosine'] = 'euclidean', plot_interval: int = 10, save_dir: str | None = None, input_dim: int | tuple | list | None = None)[source]#
Bases: TrainableCallback
Online latent visualization callback with neighborhood-preserving dimensionality reduction.
This callback learns a 2D projection that preserves neighborhood structure from high-dimensional features. It uses a contrastive loss that attracts neighbors and repels non-neighbors in the 2D space.
- The loss function is:
L = -∑_{ij} P_{ij} log Q_{ij} - ∑_{i, j ∈ Neg(i)} log(1 - Q_{ij})
- where:
P_{ij} is the high-D neighborhood graph (based on k-NN)
Q_{ij} is the similarity in the learned 2D space
Neg(i) is the set of negative samples for point i
- Parameters:
name – Unique identifier for this callback instance.
input – Key in batch dict containing input features to visualize.
target – Optional key in batch dict containing labels for coloring plots. If None, points will be plotted without color coding.
projection – The projection module to train (maps high-D to 2D). Can be an nn.Module instance, a callable that returns a module, or a Hydra config to instantiate.
queue_length – Size of the circular buffer for features.
k_neighbors – Number of nearest neighbors for building P matrix.
n_negatives – Number of negative samples per positive pair.
optimizer – Optimizer configuration. If None, uses Adam (recommended for DR tasks).
scheduler – Learning rate scheduler configuration. If None, uses ConstantLR.
accumulate_grad_batches – Number of batches to accumulate gradients.
update_interval – Update projection network every N training batches (default: 10).
warmup_epochs – Number of epochs to wait before starting projection training (default: 0). Allows main model to stabilize before learning 2D projections.
distance_metric – Metric for computing distances in high-D space.
plot_interval – Interval (in epochs) for plotting 2D visualization.
save_dir – Optional directory to save plots. If None, saves to ‘latent_viz_{name}’.
input_dim – Expected dimensionality of input features (for queue).
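Example
A construction sketch; the projection head, batch keys, and dimensions are placeholders:
```python
import torch.nn as nn

from stable_pretraining.callbacks.latent_viz import LatentViz

viz = LatentViz(
    name="latent_viz",
    input="embedding",          # batch key with high-D features
    target="label",             # optional labels used to color the plot
    projection=nn.Sequential(   # small MLP mapping high-D -> 2D
        nn.Linear(768, 128),
        nn.ReLU(),
        nn.Linear(128, 2),
    ),
    queue_length=4096,
    k_neighbors=15,
    n_negatives=5,
    plot_interval=10,
    input_dim=768,
)
```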
- on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int) None[source]#
Perform projection network training step.
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Plot 2D visualization at specified intervals.
- property projection_module#
Alias for self.module for backward compatibility.
stable_pretraining.callbacks.lidar module#
LiDAR (Linear Discriminant Analysis Rank) callback for monitoring representation quality.
- Based on:
Thilak et al. “LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures” arXiv:2312.04000 (2023)
- class stable_pretraining.callbacks.lidar.LiDAR(name: str, target: str, queue_length: int, target_shape: int | Iterable[int], n_classes: int = 100, samples_per_class: int = 10, delta: float = 0.0001, epsilon: float = 1e-08)[source]#
Bases: Callback
LiDAR (Linear Discriminant Analysis Rank) monitor using queue discovery.
LiDAR measures the effective rank of learned representations using Linear Discriminant Analysis (LDA). It computes the exponential of the entropy of the eigenvalue distribution from the LDA transformation, providing a metric between 1 and min(d, n_classes - 1) where d is the feature dimension, indicating how many dimensions are effectively being used.
This implementation is based on Thilak et al. “LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures” (arXiv:2312.04000).
IMPORTANT: Surrogate Class Formation Requirement#
The LiDAR paper requires that each “surrogate class” consists of q augmented views of the same clean sample. The current implementation chunks the queue sequentially into groups of size samples_per_class. For faithful reproduction of the paper:
Ensure the upstream queue pushes q contiguous augmentations of each clean sample
OR implement ID-based grouping to ensure each group contains views of the same sample
Without proper grouping, the metric may not accurately reflect the paper’s methodology.
The metric helps detect:
- Dimensional collapse in self-supervised learning
- Loss of representational capacity
- Over-regularization effects
- Parameters:
name – Unique identifier for this callback instance
target – Key in batch dict containing the feature embeddings to monitor
queue_length – Size of the circular buffer for caching embeddings
target_shape – Shape of the target embeddings (e.g., 768 for 768-dim features)
n_classes – Number of surrogate classes (clean samples) for LDA computation
samples_per_class – Number of augmented samples per class
delta – Regularization constant added to within-class covariance (default: 1e-4)
epsilon – Small constant for numerical stability (default: 1e-8)
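Example
A construction sketch; the batch key and feature dimension are placeholders, and the queue is sized here to n_classes * samples_per_class for illustration:
```python
from stable_pretraining.callbacks.lidar import LiDAR

lidar = LiDAR(
    name="lidar",
    target="embedding",    # batch key with feature embeddings
    queue_length=1000,     # 100 surrogate classes * 10 samples per class
    target_shape=768,      # 768-dim features
    n_classes=100,
    samples_per_class=10,
)
```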
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) None[source]#
Compute LiDAR metric on the first validation batch only.
stable_pretraining.callbacks.probe module#
- class stable_pretraining.callbacks.probe.OnlineProbe(module: LightningModule, name: str, input: str, target: str, probe: Module, loss_fn: callable = None, optimizer: str | dict | partial | Optimizer | None = None, scheduler: str | dict | partial | LRScheduler | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: float = None, gradient_clip_algorithm: str = 'norm', metrics: dict | tuple | list | Metric | None = None)[source]#
Bases: TrainableCallback
Online probe for evaluating learned representations during self-supervised training.
This callback implements the standard linear evaluation protocol by training a probe (typically a linear classifier) on top of frozen features from the main model. The probe is trained simultaneously with the main model but maintains its own optimizer, scheduler, and training loop. This allows monitoring representation quality throughout training without modifying the base model.
Key features:
- Automatic gradient detachment to prevent probe gradients from affecting the main model
- Independent optimizer and scheduler management
- Support for gradient accumulation
- Mixed precision training compatibility through automatic dtype conversion
- Metric tracking and logging
- Parameters:
module – The spt.LightningModule to probe.
name – Unique identifier for this probe instance. Used for logging and storing metrics/modules.
input – Key in batch dict or outputs dict containing input features to probe.
target – Key in batch dict containing ground truth target labels.
probe – The probe module to train. Can be a nn.Module instance, callable that returns a module, or Hydra config to instantiate.
loss_fn – Loss function for probe training (e.g., nn.CrossEntropyLoss()).
optimizer – Optimizer configuration for the probe. Can be: a str optimizer name (e.g., "AdamW", "SGD", "LARS"); a dict such as {"type": "AdamW", "lr": 1e-3, …}; a partial (pre-configured optimizer factory); an optimizer instance or callable; or None, which uses LARS(lr=0.1, clip_lr=True, eta=0.02, exclude_bias_n_norm=True, weight_decay=0), the standard for SSL linear probes (default).
scheduler – Learning rate scheduler configuration. Can be: a str scheduler name (e.g., "CosineAnnealingLR", "StepLR"); a dict such as {"type": "CosineAnnealingLR", "T_max": 1000, …}; a partial (pre-configured scheduler factory); a scheduler instance or callable; or None, which uses ConstantLR(factor=1.0), maintaining a constant learning rate (default).
accumulate_grad_batches – Number of batches to accumulate gradients before optimizer step. Default is 1 (no accumulation).
metrics – Metrics to track during training/validation. Can be dict, list, tuple, or single metric instance.
Note
The probe module is stored in pl_module.callbacks_modules[name]
Metrics are stored in pl_module.callbacks_metrics[name]
Predictions are stored in batch dict with key ‘{name}_preds’
Loss is logged as ‘train/{name}_loss’
Metrics are logged with prefix ‘train/{name}_’ and ‘eval/{name}_’
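Example
A construction sketch of a linear probe; spt_module stands in for an existing stable_pretraining LightningModule (not constructed here), and the batch keys and dimensions are placeholders:
```python
import torch.nn as nn
from torchmetrics.classification import MulticlassAccuracy

from stable_pretraining.callbacks.probe import OnlineProbe

# spt_module: your stable_pretraining LightningModule instance (placeholder)
probe = OnlineProbe(
    module=spt_module,
    name="linear_probe",
    input="embedding",            # frozen features to probe
    target="label",               # ground-truth labels
    probe=nn.Linear(768, 1000),   # linear classifier head
    loss_fn=nn.CrossEntropyLoss(),
    metrics={"top1": MulticlassAccuracy(num_classes=1000)},
)
```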
stable_pretraining.callbacks.queue module#
Queue callback with unified size management and insertion order preservation.
This module provides a queue callback that uses OrderedQueue to maintain insertion order and implements intelligent queue sharing when multiple callbacks request the same data with different queue sizes.
- class stable_pretraining.callbacks.queue.OnlineQueue(key: str, queue_length: int, dim: int | tuple | None = None, dtype: dtype | None = None, gather_distributed: bool = False)[source]#
Bases: Callback
Circular buffer callback with insertion order preservation and size unification.
This callback maintains an OrderedQueue that accumulates data from specified batch keys during training while preserving insertion order. It implements intelligent queue sharing: when multiple callbacks request the same data with different sizes, it uses a single queue with the maximum size and serves appropriate subsets.
Key features:
- Maintains insertion order using OrderedQueue
- Unified storage: one queue per key, shared across different size requests
- Memory-efficient: no duplicate storage for the same data
- Size-based retrieval: each consumer gets exactly the amount it needs
- Parameters:
key – The batch key whose tensor values will be queued at every training step.
queue_length – Number of elements this callback needs from the queue.
dim – Pre-allocate buffer with this shape. Can be int or tuple.
dtype – Pre-allocate buffer with this dtype.
gather_distributed – If True, gather queue data across all processes.
- data#
Property returning the requested number of most recent samples.
- actual_queue_length#
The actual size of the underlying shared queue.
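Example
A minimal sketch queuing an illustrative batch key:
```python
import torch
from lightning.pytorch import Trainer

from stable_pretraining.callbacks.queue import OnlineQueue

# Keep the last 8192 "embedding" tensors seen during training
queue_cb = OnlineQueue(
    key="embedding",
    queue_length=8192,
    dim=768,
    dtype=torch.float32,
)
trainer = Trainer(callbacks=[queue_cb])
```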
- on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int) None[source]#
Append batch data to the shared queue.
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Clean up snapshot after validation.
- stable_pretraining.callbacks.queue.find_or_create_queue_callback(trainer: Trainer, key: str, queue_length: int, dim: int | tuple | None = None, dtype: dtype | None = None, gather_distributed: bool = False, create_if_missing: bool = True) OnlineQueue[source]#
Find or create an OnlineQueue callback with unified size management.
This function implements intelligent queue unification:
- If a queue exists for the key with a different size, it reuses the same underlying queue and adjusts its size if needed
- Each callback gets exactly the amount of data it requests
- Memory is optimized by sharing the same storage
- Parameters:
trainer – The Lightning trainer containing callbacks
key – The batch key to look for
queue_length – Number of samples this callback needs
dim – Required dimension (None means any)
dtype – Required dtype (None means any)
gather_distributed – Whether to gather across distributed processes
create_if_missing – If True, create queue when not found
- Returns:
The matching or newly created OnlineQueue callback
- Raises:
ValueError – If no matching queue is found and create_if_missing is False
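Example
A usage sketch from inside another callback's setup hook; the surrounding callback class, batch key, and sizes are hypothetical:
```python
from lightning.pytorch import Callback

from stable_pretraining.callbacks.queue import find_or_create_queue_callback


class MyEvaluator(Callback):  # hypothetical consumer of queued features
    def setup(self, trainer, pl_module, stage=None):
        # Reuse an existing queue for "embedding" if one is registered,
        # otherwise create one sized for this callback's needs.
        self._queue = find_or_create_queue_callback(
            trainer,
            key="embedding",
            queue_length=4096,
            dim=768,
        )
```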
stable_pretraining.callbacks.rankme module#
- class stable_pretraining.callbacks.rankme.RankMe(name: str, target: str, queue_length: int, target_shape: int | Iterable[int])[source]#
Bases: Callback
RankMe (effective rank) monitor using queue discovery.
RankMe measures the effective rank of feature representations by computing the exponential of the entropy of normalized singular values. This metric helps detect dimensional collapse in self-supervised learning.
- Parameters:
name – Unique name for this callback instance
target – Key in batch dict containing the feature embeddings to monitor
queue_length – Required queue length
target_shape – Shape of the target embeddings (e.g., 768 for 768-dim features)
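Example
A construction sketch; the batch key and feature dimension are illustrative:
```python
from stable_pretraining.callbacks.rankme import RankMe

rankme = RankMe(
    name="rankme",
    target="embedding",   # batch key with feature embeddings
    queue_length=8192,
    target_shape=768,
)
```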
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) None[source]#
Compute RankMe metric on the first validation batch only.
stable_pretraining.callbacks.teacher_student module#
Callback for automatic TeacherStudentWrapper EMA updates.
- class stable_pretraining.callbacks.teacher_student.TeacherStudentCallback(update_frequency: int = 1, update_after_backward: bool = False)[source]#
Bases: Callback
Automatically updates TeacherStudentWrapper instances during training.
This callback handles the EMA (Exponential Moving Average) updates for any TeacherStudentWrapper instances found in the model. It updates both the teacher parameters and the EMA coefficient schedule.
The callback automatically detects all TeacherStudentWrapper instances in the model hierarchy and updates them at the appropriate times during training.
- Parameters:
update_frequency (int) – Update the teacher every N training batches. Default is 1.
update_after_backward (bool) – If True, update the teacher right after the backward pass instead of at the end of the training batch. Default is False.
Example
>>> backbone = ResNet18()
>>> wrapped_backbone = TeacherStudentWrapper(backbone)
>>> module = ssl.Module(backbone=wrapped_backbone, ...)
>>> trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
- on_after_backward(trainer: Trainer, pl_module: LightningModule) None[source]#
Update teacher models after backward pass if update_after_backward is True.
stable_pretraining.callbacks.trainer_info module#
- class stable_pretraining.callbacks.trainer_info.LoggingCallback[source]#
Bases: Callback
Displays validation metrics in a color-coded, formatted table.
This callback creates a visually appealing table of all validation metrics at the end of each validation epoch. Metrics are color-coded for better readability in terminal outputs.
Features:
- Automatic sorting of metrics by name
- Color coding: blue for metric names, green for values
- Filters out internal metrics (log, progress_bar)
- class stable_pretraining.callbacks.trainer_info.ModuleSummary[source]#
Bases: Callback
Logs detailed module parameter statistics in a formatted table.
This callback provides a comprehensive overview of all modules in the model, showing the number of trainable, non-trainable, uninitialized parameters, and buffers for each module. This helps understand model architecture and parameter distribution.
The summary is displayed during the setup phase and includes:
- Module name and hierarchy
- Trainable parameter count
- Non-trainable (frozen) parameter count
- Uninitialized parameter count (for lazy modules)
- Buffer count (non-parameter persistent state)
- class stable_pretraining.callbacks.trainer_info.SLURMInfo[source]#
Bases: Callback
Links the trainer to the DataModule for enhanced functionality.
This callback establishes a bidirectional connection between the trainer and DataModule, enabling the DataModule to access trainer information such as device placement, distributed training state, and other runtime configurations.
This is particularly useful for DataModules that need to adapt their behavior based on trainer configuration (e.g., device-aware data loading, distributed sampling adjustments).
Note
Only works with DataModule instances that have a set_pl_trainer method. A warning is logged if using a custom DataModule without this method.
- class stable_pretraining.callbacks.trainer_info.TrainerInfo[source]#
Bases: Callback
Links the trainer to the DataModule for enhanced functionality.
This callback establishes a bidirectional connection between the trainer and DataModule, enabling the DataModule to access trainer information such as device placement, distributed training state, and other runtime configurations.
This is particularly useful for DataModules that need to adapt their behavior based on trainer configuration (e.g., device-aware data loading, distributed sampling adjustments).
Note
Only works with DataModule instances that have a set_pl_trainer method. A warning is logged if using a custom DataModule without this method.
stable_pretraining.callbacks.utils module#
- class stable_pretraining.callbacks.utils.EarlyStopping(mode: str = 'min', milestones: dict[int, float] = None, metric_name: str = None, patience: int = 10)[source]#
Bases: Module
Early stopping mechanism with support for metric milestones and patience.
This module provides flexible early stopping capabilities that can halt training based on metric performance. It supports both milestone-based stopping (stop if metric doesn’t reach target by specific epochs) and patience-based stopping (stop if metric doesn’t improve for N epochs).
- Parameters:
mode – Optimization direction - ‘min’ for metrics to minimize (e.g., loss), ‘max’ for metrics to maximize (e.g., accuracy).
milestones – Dict mapping epoch numbers to target metric values. Training stops if targets are not met at specified epochs.
metric_name – Name of the metric to monitor if metric is a dict.
patience – Number of epochs with no improvement before stopping.
Example
>>> early_stop = EarlyStopping(mode="max", milestones={10: 0.8, 20: 0.9})
>>> # Stops if accuracy < 0.8 at epoch 10 or < 0.9 at epoch 20
- class stable_pretraining.callbacks.utils.TrainableCallback(module: LightningModule, name: str, optimizer: str | dict | partial | Optimizer | None = None, scheduler: str | dict | partial | LRScheduler | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: float = None, gradient_clip_algorithm: str = 'norm')[source]#
Bases: Callback
Base callback class with optimizer and scheduler management.
This base class handles the common logic for callbacks that need their own optimizer and scheduler, including automatic inheritance from the main module’s configuration when not explicitly specified.
Subclasses should:
1. Call super().__init__() with appropriate parameters
2. Store their module configuration in self._module_config
3. Override configure_model() to create their specific module
4. Access their module via the self.module property after setup
- configure_model(pl_module: LightningModule) Module[source]#
Initialize the module for this callback.
Subclasses must override this method to create their specific module.
- Parameters:
pl_module – The Lightning module being trained.
- Returns:
The initialized module.
- property module#
Access module from pl_module.callbacks_modules.
This property is only accessible after setup() has been called. The module is stored centrally in pl_module.callbacks_modules to avoid duplication in checkpoints.
- setup_optimizer(pl_module: LightningModule) None[source]#
Initialize optimizer with default LARS if not specified.
- stable_pretraining.callbacks.utils.format_metrics_as_dict(metrics)[source]#
Formats various metric input formats into a standardized dictionary structure.
This utility function handles multiple input formats for metrics and converts them into a consistent ModuleDict structure with separate train and validation metrics. This standardization simplifies metric handling across callbacks.
- Parameters:
metrics – Can be: None (returns empty train and val dicts), a single torchmetrics.Metric (applied to validation only), a dict with 'train' and 'val' keys (separated accordingly), a dict of metrics (all applied to validation), or a list/tuple of metrics (all applied to validation).
- Returns:
ModuleDict with ‘_train’ and ‘_val’ keys, each containing metric ModuleDicts.
- Raises:
ValueError – If metrics format is invalid or contains non-torchmetric objects.
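Example
A usage sketch; the metric choice is illustrative:
```python
from torchmetrics.classification import MulticlassAccuracy

from stable_pretraining.callbacks.utils import format_metrics_as_dict

# A plain dict of metrics is applied to validation only
formatted = format_metrics_as_dict({"top1": MulticlassAccuracy(num_classes=10)})
# formatted is a ModuleDict with "_train" and "_val" entries
```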
stable_pretraining.callbacks.wd_schedule module#
- class stable_pretraining.callbacks.wd_schedule.WeightDecayUpdater(schedule_type: str = 'cosine', start_value: float = 0.01, end_value: float = 0.0, param_group_indices: list = None, opt_idx: int = None)[source]#
Bases: Callback
PyTorch Lightning callback to update the optimizer's weight decay every batch.
Supports multiple schedules: ‘constant’, ‘linear’, ‘cosine’, ‘exponential’
Optionally specify which optimizer param group(s) to update (by index)
Infers total steps from Trainer config (max_steps or max_epochs + dataloader)
Checkpointable: state is saved/restored with Trainer checkpoints
Extensive Loguru logging
- Parameters:
schedule_type (str) – One of ‘constant’, ‘linear’, ‘cosine’, ‘exponential’
start_value (float) – Initial weight decay value
end_value (float) – Final weight decay value (for non-constant schedules)
param_group_indices (list[int] or None) – List of param group indices to update. If None, updates all.
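Example
A minimal sketch of a cosine weight-decay schedule; the start/end values are illustrative, not recommended defaults:
```python
from lightning.pytorch import Trainer

from stable_pretraining.callbacks.wd_schedule import WeightDecayUpdater

wd_cb = WeightDecayUpdater(
    schedule_type="cosine",
    start_value=0.04,
    end_value=0.4,
    param_group_indices=None,  # update every param group
)
trainer = Trainer(callbacks=[wd_cb])
```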
- load_state_dict(state_dict)[source]#
Called when loading a checkpoint; implement to reload the callback state given the callback's state_dict.
- Parameters:
state_dict – the callback state returned by state_dict.
stable_pretraining.callbacks.writer module#
- class stable_pretraining.callbacks.writer.OnlineWriter(names: str, path: str | Path, during: str | list[str], every_k_epochs: int = -1, save_last_epoch: bool = False, save_sanity_check: bool = False, all_gather: bool = True)[source]#
Bases: Callback
Writes specified batch data to disk during training and validation.
This callback enables selective saving of batch data (e.g., features, predictions, embeddings) to disk at specified intervals during training. It’s useful for debugging, visualization, and analysis of model behavior during training.
Features:
- Flexible saving schedule (every k epochs, last epoch, sanity check)
- Support for distributed training with optional all_gather
- Automatic directory creation
- Configurable for different training phases (train, val, test)
- Parameters:
names – Name(s) of the batch keys to save. Can be string or list of strings.
path – Directory path where files will be saved.
during – Training phase(s) when to save (‘train’, ‘val’, ‘test’, or list).
every_k_epochs – Save every k epochs. -1 means every epoch.
save_last_epoch – Whether to save on the last training epoch.
save_sanity_check – Whether to save during sanity check phase.
all_gather – Whether to gather data across all distributed processes.
Files are saved with naming pattern: {phase}_{name}_epoch{epoch}_batch{batch}.pt
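Example
A construction sketch; the batch key and output path are placeholders:
```python
from stable_pretraining.callbacks.writer import OnlineWriter

writer = OnlineWriter(
    names="embedding",   # batch key(s) to dump
    path="./dumps",
    during="val",        # only during validation
    every_k_epochs=5,
    all_gather=True,
)
```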
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called when the predict batch ends.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called when the test batch ends.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.