stable_pretraining.forward#
Forward functions define the core training logic for different self-supervised learning methods. They are called during the forward pass of the Module
and determine how data flows through the model, how losses are computed, and what outputs are returned.
Overview#
Forward functions are the heart of each SSL method implementation. They:
Take a batch of data and the training stage as input
Process data through backbone and projection heads
Compute method-specific losses during training
Return a dictionary containing loss and embeddings
The forward function is bound to the Module instance at initialization, giving it access to all module attributes (backbone, projector, loss functions, etc.).
Usage in Config#
Forward functions can be specified in YAML configs as string references:
module:
_target_: stable_pretraining.Module
forward: stable_pretraining.forward.simclr_forward
backbone: ...
projector: ...
simclr_loss: ...
Or in Python code:
from stable_pretraining import Module
from stable_pretraining.forward import simclr_forward
module = Module(
forward=simclr_forward,
backbone=backbone,
projector=projector,
simclr_loss=loss_fn
)
Available Forward Functions#
SimCLR#
- stable_pretraining.forward.simclr_forward(self, batch, stage)[source]#
Forward function for SimCLR (Simple Contrastive Learning of Representations).
SimCLR learns representations by maximizing agreement between differently augmented views of the same image via a contrastive loss in the latent space.
- Parameters:
self – Module instance (automatically bound) with required attributes:
  - backbone: Feature extraction network
  - projector: Projection head mapping features to latent space
  - simclr_loss: NT-Xent contrastive loss function
batch – Input batch dictionary containing:
  - ‘image’: Tensor of augmented images [N*views, C, H, W]
  - ‘sample_idx’: Indices to identify views of same image
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
‘embedding’: Feature representations from backbone
‘loss’: NT-Xent contrastive loss (during training only)
- Return type:
Dictionary
Note
Introduced in the SimCLR paper [Chen et al., 2020].
Required Module Attributes:
backbone: Feature extraction network
projector: Projection head for embedding transformation
simclr_loss: NTXent contrastive loss function
Example Config:
module:
forward: stable_pretraining.forward.simclr_forward
backbone:
_target_: stable_pretraining.backbone.from_torchvision
model_name: resnet50
projector:
_target_: torch.nn.Sequential
_args_:
- _target_: torch.nn.Linear
in_features: 2048
out_features: 2048
- _target_: torch.nn.ReLU
- _target_: torch.nn.Linear
in_features: 2048
out_features: 128
simclr_loss:
_target_: stable_pretraining.losses.NTXEntLoss
temperature: 0.5
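For intuition, here is a minimal sketch of the logic such a forward function implements. It is not the library's implementation: it assumes exactly two views per sample stacked in order, and that the loss accepts two projection batches (the real code can regroup views via batch["sample_idx"], e.g. with fold_views from stable_pretraining.data):
def simclr_style_forward(self, batch, stage):
    # Sketch only: encode all stacked views, project, and contrast.
    out = {"embedding": self.backbone(batch["image"])}
    if self.training:
        proj = self.projector(out["embedding"])
        # Assumption: two views per sample, stacked as [view0; view1];
        # the library uses batch["sample_idx"] to identify views.
        z1, z2 = proj.chunk(2, dim=0)
        out["loss"] = self.simclr_loss(z1, z2)
    return out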
NNCLR#
- stable_pretraining.forward.nnclr_forward(self, batch, stage)[source]#
Forward function for NNCLR (Nearest-Neighbor Contrastive Learning).
NNCLR learns representations by using the nearest neighbor of an augmented view from a support set of past embeddings as a positive pair. This encourages the model to learn representations that are similar for semantically similar instances, not just for different augmentations of the same instance.
- Parameters:
self – Module instance (automatically bound) with required attributes:
  - backbone: Feature extraction network
  - projector: Projection head for embedding transformation
  - predictor: Prediction head used for the online view
  - nnclr_loss: NTXent contrastive loss function
batch – Input batch dictionary containing:
  - ‘image’: Tensor of augmented images [N*views, C, H, W]
  - ‘sample_idx’: Indices to identify views of same image
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
‘embedding’: Feature representations from backbone
‘loss’: NTXent contrastive loss (during training only)
‘nnclr_support_set’: Projections to be added to the support set queue
- Return type:
Dictionary
Note
Introduced in the NNCLR paper [Dwibedi et al., 2021].
Required Module Attributes:
backbone: Feature extraction network
projector: Projection head for embedding transformation
predictor: Prediction head for the online network
nnclr_loss: NTXent contrastive loss function
Key Features:
Uses a support set of past embeddings to find nearest-neighbor positives.
Encourages semantic similarity, going beyond instance-level discrimination.
Requires an OnlineQueue callback with a matching key.
Example Config:
module:
forward: stable_pretraining.forward.nnclr_forward
backbone:
_target_: stable_pretraining.backbone.from_torchvision
model_name: resnet18
projector:
_target_: torch.nn.Sequential
_args_:
- _target_: torch.nn.Linear
in_features: 512
out_features: 2048
- _target_: torch.nn.BatchNorm1d
num_features: 2048
- _target_: torch.nn.ReLU
- _target_: torch.nn.Linear
in_features: 2048
out_features: 256
predictor:
_target_: torch.nn.Sequential
_args_:
- _target_: torch.nn.Linear
in_features: 256
out_features: 4096
- _target_: torch.nn.BatchNorm1d
num_features: 4096
- _target_: torch.nn.ReLU
- _target_: torch.nn.Linear
in_features: 4096
out_features: 256
nnclr_loss:
_target_: stable_pretraining.losses.NTXEntLoss
temperature: 0.5
hparams:
support_set_size: 16384
projection_dim: 256
callbacks:
- _target_: stable_pretraining.callbacks.OnlineQueue
key: nnclr_support_set
queue_length: ${module.hparams.support_set_size}
dim: ${module.hparams.projection_dim}
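The distinguishing step is the nearest-neighbor lookup against the support set, sketched below with support_set standing in as a plain tensor for the contents of the OnlineQueue (a hypothetical helper, not the library's API):
import torch
import torch.nn.functional as F

def nearest_neighbor_sketch(proj: torch.Tensor, support_set: torch.Tensor) -> torch.Tensor:
    """Return, for each projection, its most similar support-set entry."""
    # Cosine similarity between current projections and queued embeddings
    sim = F.normalize(proj, dim=-1) @ F.normalize(support_set, dim=-1).T
    # The retrieved neighbors serve as positives in the contrastive loss;
    # gradients do not flow into queue entries.
    return support_set[sim.argmax(dim=-1)]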
BYOL#
- stable_pretraining.forward.byol_forward(self, batch, stage)[source]#
Forward function for BYOL (Bootstrap Your Own Latent).
BYOL learns representations without negative pairs by using a momentum-based target network and predicting target projections from online projections.
- Parameters:
self – Module instance with required attributes:
  - backbone: TeacherStudentWrapper for feature extraction
  - projector: TeacherStudentWrapper for projection head
  - predictor: Online network predictor
  - byol_loss: BYOL loss function (optional, uses MSE if not provided)
batch – Input batch dictionary containing:
  - ‘image’: Tensor of augmented images [N*views, C, H, W]
  - ‘sample_idx’: Indices to identify views of same image
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
‘embedding’: Feature representations from teacher backbone (EMA target)
‘loss’: BYOL loss between predictions and targets (during training)
- Return type:
Dictionary
Note
Introduced in the BYOL paper [Grill et al., 2020].
Required Module Attributes:
backbone: Online network backbone
projector: Online network projector
predictor: Online network predictor
target_backbone: Target network backbone (momentum encoder)
target_projector: Target network projector
Key Features:
Uses momentum encoder for target network
No negative pairs required
MSE loss between predictions and targets
Example Config:
module:
forward: stable_pretraining.forward.byol_forward
backbone: ...
projector: ...
predictor:
_target_: torch.nn.Sequential
_args_:
- _target_: torch.nn.Linear
in_features: 256
out_features: 4096
- _target_: torch.nn.BatchNorm1d
num_features: 4096
- _target_: torch.nn.ReLU
- _target_: torch.nn.Linear
in_features: 4096
out_features: 256
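A rough sketch of the data flow, written against the target_backbone / target_projector attributes listed above (the actual implementation may instead route both roles through TeacherStudentWrapper):
import torch
import torch.nn.functional as F

def byol_style_forward(self, batch, stage):
    # Sketch only: the momentum (target) network produces regression
    # targets; the online network predicts them through the predictor.
    out = {}
    with torch.no_grad():  # no gradients through the EMA target network
        target_feat = self.target_backbone(batch["image"])
        target_proj = self.target_projector(target_feat)
    out["embedding"] = target_feat  # teacher features, as documented above
    if self.training:
        pred = self.predictor(self.projector(self.backbone(batch["image"])))
        # MSE on L2-normalized vectors, as in the BYOL paper
        out["loss"] = F.mse_loss(
            F.normalize(pred, dim=-1), F.normalize(target_proj, dim=-1)
        )
    return out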
VICReg#
- stable_pretraining.forward.vicreg_forward(self, batch, stage)[source]#
Forward function for VICReg (Variance-Invariance-Covariance Regularization).
VICReg learns representations using three criteria: variance (maintaining information), invariance (to augmentations), and covariance (decorrelating features).
- Parameters:
self – Module instance (automatically bound) with required attributes:
  - backbone: Feature extraction network
  - projector: Projection head for embedding transformation
  - vicreg_loss: VICReg loss with variance, invariance, covariance terms
batch – Input batch dictionary containing:
  - ‘image’: Tensor of augmented images [N*views, C, H, W]
  - ‘sample_idx’: Indices to identify views of same image
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
‘embedding’: Feature representations from backbone
‘loss’: Combined VICReg loss (during training only)
- Return type:
Dictionary
Note
Introduced in the VICReg paper [Bardes et al., 2022].
Required Module Attributes:
backbone: Feature extraction network
projector: Projection head
vicreg_loss: VICReg loss (variance + invariance + covariance)
Key Features:
Variance regularization to maintain information
Invariance to augmentations
Covariance regularization to decorrelate features
Example Config:
module:
forward: stable_pretraining.forward.vicreg_forward
backbone: ...
projector: ...
vicreg_loss:
_target_: stable_pretraining.losses.VICRegLoss
sim_weight: 25.0
var_weight: 25.0
cov_weight: 1.0
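To make the three criteria concrete, here is an illustrative computation of the terms, shown for a single branch (the paper applies the variance and covariance terms to both views; this is not the library's VICRegLoss):
import torch
import torch.nn.functional as F

def vicreg_terms_sketch(z1, z2, eps=1e-4):
    # Invariance: projections of two views of the same image should match.
    inv = F.mse_loss(z1, z2)
    # Variance: hinge keeping each dimension's std above 1 (keeps information).
    std = torch.sqrt(z1.var(dim=0) + eps)
    var = torch.relu(1.0 - std).mean()
    # Covariance: penalize off-diagonal covariance (decorrelates features).
    zc = z1 - z1.mean(dim=0)
    cov = (zc.T @ zc) / (z1.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diagonal(cov))
    cov_pen = (off_diag ** 2).sum() / z1.shape[1]
    return inv, var, cov_pen
The sim_weight, var_weight, and cov_weight entries in the config scale the invariance, variance, and covariance terms respectively.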
Barlow Twins#
- stable_pretraining.forward.barlow_twins_forward(self, batch, stage)[source]#
Forward function for Barlow Twins.
Barlow Twins learns representations by making the cross-correlation matrix between embeddings of augmented views as close to the identity matrix as possible, reducing redundancy while maintaining invariance.
- Parameters:
self – Module instance (automatically bound) with required attributes:
  - backbone: Feature extraction network
  - projector: Projection head (typically with BN and high dimension)
  - barlow_loss: Barlow Twins loss function
batch – Input batch dictionary containing:
  - ‘image’: Tensor of augmented images [N*views, C, H, W]
  - ‘sample_idx’: Indices to identify views of same image
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
‘embedding’: Feature representations from backbone
‘loss’: Barlow Twins loss (during training only)
- Return type:
Dictionary
Note
Introduced in the Barlow Twins paper [Zbontar et al., 2021].
Required Module Attributes:
backbone: Feature extraction network
projector: Projection head
barlow_loss: Barlow Twins loss function
Key Features:
Reduces redundancy between embedding components
Makes cross-correlation matrix close to identity
No negative pairs or momentum encoder needed
Example Config:
module:
forward: stable_pretraining.forward.barlow_twins_forward
backbone: ...
projector: ...
barlow_loss:
_target_: stable_pretraining.losses.BarlowTwinsLoss
lambda_: 0.005
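An illustrative version of the objective (not the library's BarlowTwinsLoss) showing the role of lambda_:
import torch

def barlow_twins_loss_sketch(z1, z2, lambda_=0.005):
    # Standardize each embedding dimension over the batch.
    z1 = (z1 - z1.mean(dim=0)) / z1.std(dim=0)
    z2 = (z2 - z2.mean(dim=0)) / z2.std(dim=0)
    # D x D cross-correlation matrix between the two views.
    c = (z1.T @ z2) / z1.shape[0]
    on_diag = (torch.diagonal(c) - 1) ** 2  # invariance: diagonal -> 1
    off_diag = (c - torch.diag(torch.diagonal(c))) ** 2  # redundancy -> 0
    return on_diag.sum() + lambda_ * off_diag.sum()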
Supervised#
- stable_pretraining.forward.supervised_forward(self, batch, stage)[source]#
Forward function for standard supervised training.
This function implements traditional supervised learning with labels, useful for baseline comparisons and fine-tuning pre-trained models.
- Parameters:
self – Module instance (automatically bound) with required attributes:
  - backbone: Feature extraction network
  - classifier: Classification head (e.g., Linear layer)
batch – Input batch dictionary containing:
  - ‘image’: Tensor of images [N, C, H, W]
  - ‘label’: Ground truth labels [N] (optional, for loss computation)
stage – Training stage (‘train’, ‘val’, or ‘test’)
- Returns:
‘embedding’: Feature representations from backbone
‘logits’: Classification predictions
‘loss’: Cross-entropy loss (if labels provided)
- Return type:
Dictionary
Note
Unlike SSL methods, this function uses actual labels for training and is primarily used for evaluation or supervised baselines.
Required Module Attributes:
backbone: Feature extraction network
classifier: Classification head
Key Features:
Standard supervised learning
Cross-entropy loss for classification
Useful for baseline comparisons
Example Config:
module:
forward: stable_pretraining.forward.supervised_forward
backbone: ...
classifier:
_target_: torch.nn.Linear
in_features: 2048
out_features: 1000
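A sketch consistent with the documented inputs and outputs (not necessarily the library's exact implementation):
import torch.nn.functional as F

def supervised_style_forward(self, batch, stage):
    out = {"embedding": self.backbone(batch["image"])}
    out["logits"] = self.classifier(out["embedding"])
    if "label" in batch:  # labels are optional, e.g. absent at inference
        out["loss"] = F.cross_entropy(out["logits"], batch["label"])
    return out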
Custom Forward Functions#
You can create custom forward functions for new SSL methods:
def custom_ssl_forward(self, batch, stage):
    """Custom SSL method forward function.

    Args:
        self: Module instance with access to all attributes
        batch: Dict containing 'image' and other data
        stage: One of 'train', 'val', or 'test'

    Returns:
        Dict with at least 'loss' key during training
    """
    out = {}
    # Extract features
    out["embedding"] = self.backbone(batch["image"])
    if self.training:
        # Your custom SSL logic here
        proj = self.projector(out["embedding"])
        # Compute custom loss
        out["loss"] = self.custom_loss(proj)
    return out
Requirements for Custom Functions:
Signature: Must accept (self, batch, stage)
Return: Dictionary with a "loss" key during training (this is the only hardcoded requirement)
Training Mode: Check self.training or stage == "train"
Outputs: You can return any other keys you want (embeddings, projections, logits, etc.) with any names you choose
Important Note on Output Keys:
The keys you use in the output dictionary (like "embedding", "logits", etc.) are not hardcoded requirements, but they serve as references for callbacks. For example:
If you return out["embedding"], callbacks can access it via outputs["embedding"]
If you return out["features"], callbacks would access outputs["features"]
The OnlineProbe callback expects its input parameter to match one of your output keys
This allows flexible integration between your forward function and various callbacks.
Example: Connecting Forward Outputs to Callbacks
module:
forward: my_custom_forward
# Forward function returns: {"loss": ..., "my_features": ..., "my_projection": ...}
callbacks:
- _target_: stable_pretraining.callbacks.OnlineProbe
name: probe1
input: my_features # References the "my_features" key from forward output
target: label
- _target_: stable_pretraining.callbacks.OnlineKNN
name: knn1
input: my_projection # References the "my_projection" key from forward output
target: label
The callback’s input parameter must match the key name you chose in your forward function.
Integration with Module#
The forward function becomes the forward method of the Module:
import pytorch_lightning as pl

class Module(pl.LightningModule):
    def __init__(self, forward, **kwargs):
        super().__init__()
        # Bind the forward function to this instance
        self.forward = forward.__get__(self, self.__class__)

    def training_step(self, batch, batch_idx):
        # The bound forward function is called here
        state = self(batch, stage="train")
        return state["loss"]
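The __get__ call is ordinary Python descriptor machinery: it turns a plain function into a method bound to a specific instance. A standalone illustration, independent of Lightning:
class Host:
    def __init__(self, fn):
        # Bind the plain function so it receives this instance as `self`
        self.fn = fn.__get__(self, self.__class__)

def describe(self):
    return f"bound to {type(self).__name__}"

host = Host(describe)
assert host.fn() == "bound to Host"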
This design allows maximum flexibility while keeping the implementation clean and modular.
See Also#
stable_pretraining.losses - Loss functions used by forward functions
stable_pretraining.module - The Module class that uses forward functions
stable_pretraining.data - Data utilities including fold_views