stable_pretraining.forward#

Forward functions define the core training logic for different self-supervised learning methods. They are called during the forward pass of the Module and determine how data flows through the model, how losses are computed, and which outputs are returned.

Overview#

Forward functions are the heart of each SSL method implementation. They:

  • Take a batch of data and the current stage (‘train’, ‘val’, or ‘test’) as input

  • Process data through backbone and projection heads

  • Compute method-specific losses during training

  • Return a dictionary containing loss and embeddings

The forward function is bound to the Module instance at initialization, giving it access to all module attributes (backbone, projector, loss functions, etc.).
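The binding mechanism can be illustrated with plain Python. The sketch below is not the library's code: `TinyModule` and `toy_forward` are hypothetical stand-ins that only show how a free function with a `self` parameter gains access to instance attributes once it is bound.

```python
import types


class TinyModule:
    """Hypothetical stand-in for Module; stores every keyword argument as an attribute."""

    def __init__(self, forward, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Bind the free function so that `self` refers to this instance.
        self.forward = types.MethodType(forward, self)

    def __call__(self, batch, stage):
        return self.forward(batch, stage)


def toy_forward(self, batch, stage):
    # `self.backbone` resolves to whatever was passed to TinyModule(...).
    out = {"embedding": [self.backbone(x) for x in batch["image"]]}
    if stage == "train":
        out["loss"] = sum(out["embedding"]) / len(out["embedding"])
    return out


module = TinyModule(forward=toy_forward, backbone=lambda x: 2 * x)
result = module({"image": [1.0, 2.0, 3.0]}, stage="train")
```

Because the function is bound per instance, two modules can share the same class while running different forward logic.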

Usage in Config#

Forward functions can be specified in YAML configs as string references:

module:
  _target_: stable_pretraining.Module
  forward: stable_pretraining.forward.simclr_forward
  backbone: ...
  projector: ...
  simclr_loss: ...

Or in Python code:

from stable_pretraining import Module
from stable_pretraining.forward import simclr_forward

# backbone, projector, and loss_fn are assumed to be constructed beforehand
module = Module(
    forward=simclr_forward,
    backbone=backbone,
    projector=projector,
    simclr_loss=loss_fn,
)

Available Forward Functions#

SimCLR#

stable_pretraining.forward.simclr_forward(self, batch, stage)

Forward function for SimCLR (Simple Contrastive Learning of Representations).

SimCLR learns representations by maximizing agreement between differently augmented views of the same image via a contrastive loss in the latent space.
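As a rough illustration of the contrastive objective, here is a dependency-free sketch of an NT-Xent loss. It is not the library's `simclr_loss`; the function name `nt_xent` and the list-based inputs are assumptions made for readability. Rows that share a value in `sample_idx` are treated as views of the same image.

```python
import math


def nt_xent(embeddings, sample_idx, temperature=0.5):
    """NT-Xent loss over a list of embedding vectors (illustrative only)."""
    # L2-normalize so that dot products are cosine similarities.
    normed = []
    for vec in embeddings:
        norm = math.sqrt(sum(v * v for v in vec))
        normed.append([v / norm for v in vec])

    n = len(normed)
    total, count = 0.0, 0
    for i in range(n):
        # Similarity of anchor i to every other row, scaled by the temperature.
        logits = {
            j: sum(a * b for a, b in zip(normed[i], normed[j])) / temperature
            for j in range(n)
            if j != i
        }
        log_denominator = math.log(sum(math.exp(s) for s in logits.values()))
        for j, score in logits.items():
            if sample_idx[j] == sample_idx[i]:  # j is another view of the same image
                total += log_denominator - score
                count += 1
    return total / count


# Views of the same image are close together here, so the loss is low.
aligned = nt_xent([[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [0.1, 1.0]], [0, 1, 0, 1])
# Here each image's second view points the wrong way, so the loss is higher.
shuffled = nt_xent([[1.0, 0.0], [0.0, 1.0], [0.1, 1.0], [1.0, 0.1]], [0, 1, 0, 1])
```

A real implementation would operate on batched tensors; the loops are kept explicit only to make each term visible.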

Parameters:
  • self – Module instance (automatically bound) with required attributes:

      - backbone: Feature extraction network
      - projector: Projection head mapping features to latent space
      - simclr_loss: NT-Xent contrastive loss function

  • batch – Input batch dictionary containing:

      - ‘image’: Tensor of augmented images [N*views, C, H, W]
      - ‘sample_idx’: Indices identifying views of the same image

  • stage – Training stage (‘train’, ‘val’, or ‘test’)

Returns:

Dictionary containing:

  • ‘embedding’: Feature representations from backbone

  • ‘loss’: NT-Xent contrastive loss (during training only)

Note

Introduced in the SimCLR paper [Chen et al., 2020].

Required Module Attributes:

  • backbone: Feature extraction network

  • projector: Projection head for embedding transformation

  • simclr_loss: NT-Xent contrastive loss function

Example Config:

module:
  forward: stable_pretraining.forward.simclr_forward
  backbone:
    _target_: stable_pretraining.backbone.from_torchvision
    model_name: resnet50
  projector:
    _target_: torch.nn.Sequential
    _args_:
      - _target_: torch.nn.Linear
        in_features: 2048
        out_features: 2048
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Linear
        in_features: 2048
        out_features: 128
  simclr_loss:
    _target_: stable_pretraining.losses.NTXEntLoss
    temperature: 0.5

NNCLR#

stable_pretraining.forward.nnclr_forward(self, batch, stage)

Forward function for NNCLR (Nearest-Neighbor Contrastive Learning).

NNCLR learns representations by using the nearest neighbor of an augmented view from a support set of past embeddings as a positive pair. This encourages the model to learn representations that are similar for semantically similar instances, not just for different augmentations of the same instance.

Parameters:
  • self – Module instance (automatically bound) with required attributes:

      - backbone: Feature extraction network
      - projector: Projection head for embedding transformation
      - predictor: Prediction head used for the online view
      - nnclr_loss: NT-Xent contrastive loss function

  • batch – Input batch dictionary containing:

      - ‘image’: Tensor of augmented images [N*views, C, H, W]
      - ‘sample_idx’: Indices identifying views of the same image

  • stage – Training stage (‘train’, ‘val’, or ‘test’)

Returns:

Dictionary containing:

  • ‘embedding’: Feature representations from backbone

  • ‘loss’: NT-Xent contrastive loss (during training only)

  • ‘nnclr_support_set’: Projections to be added to the support set queue

Note

Introduced in the NNCLR paper [Dwibedi et al., 2021].

Required Module Attributes:

  • backbone: Feature extraction network

  • projector: Projection head for embedding transformation

  • predictor: Prediction head for the online network

  • nnclr_loss: NT-Xent contrastive loss function

Key Features:

  • Uses a support set of past embeddings to find nearest-neighbor positives.

  • Encourages semantic similarity, going beyond instance-level discrimination.

  • Requires an OnlineQueue callback with a matching key.
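The support-set lookup can be sketched without any framework. The code below is an assumption-laden illustration, not the library's implementation: it assumes the queue is a fixed-length first-in-first-out buffer and that neighbors are ranked by cosine similarity. The helper names `cosine` and `nearest_neighbor` are hypothetical.

```python
import math
from collections import deque


def cosine(a, b):
    """Cosine similarity between two vectors given as lists."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def nearest_neighbor(query, support_set):
    """Return the support-set entry most similar to `query`."""
    return max(support_set, key=lambda candidate: cosine(query, candidate))


# A bounded FIFO queue: once full, the oldest projections are dropped.
support_set = deque(maxlen=3)
for past_projection in ([1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]):
    support_set.append(past_projection)

# The neighbor, rather than the view's own projection, serves as the positive.
positive = nearest_neighbor([0.1, 0.9], support_set)
```

In the config above, `queue_length` and `dim` play the roles of the queue capacity and the vector size used here.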

Example Config:

module:
  forward: stable_pretraining.forward.nnclr_forward
  backbone:
    _target_: stable_pretraining.backbone.from_torchvision
    model_name: resnet18
  projector:
    _target_: torch.nn.Sequential
    _args_:
      - _target_: torch.nn.Linear
        in_features: 512
        out_features: 2048
      - _target_: torch.nn.BatchNorm1d
        num_features: 2048
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Linear
        in_features: 2048
        out_features: 256
  predictor:
    _target_: torch.nn.Sequential
    _args_:
      - _target_: torch.nn.Linear
        in_features: 256
        out_features: 4096
      - _target_: torch.nn.BatchNorm1d
        num_features: 4096
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Linear
        in_features: 4096
        out_features: 256
  nnclr_loss:
    _target_: stable_pretraining.losses.NTXEntLoss
    temperature: 0.5
  hparams:
    support_set_size: 16384
    projection_dim: 256

callbacks:
  - _target_: stable_pretraining.callbacks.OnlineQueue
    key: nnclr_support_set
    queue_length: ${module.hparams.support_set_size}
    dim: ${module.hparams.projection_dim}

BYOL#

stable_pretraining.forward.byol_forward(self, batch, stage)

Forward function for BYOL (Bootstrap Your Own Latent).

BYOL learns representations without negative pairs by using a momentum-based target network and predicting target projections from online projections.
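Two ingredients of this scheme can be sketched in a few lines: the momentum (EMA) update of the target network and a normalized MSE between prediction and target, as formulated in the BYOL paper. The functions `ema_update` and `byol_loss` below are illustrative assumptions operating on plain lists, not the library's wrapper or loss classes.

```python
import math


def ema_update(target_params, online_params, momentum=0.99):
    """Move each target parameter a small step toward its online counterpart."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target_params, online_params)]


def byol_loss(prediction, target):
    """Squared distance between L2-normalized vectors (equals 2 - 2 * cosine similarity)."""

    def normalize(vec):
        norm = math.sqrt(sum(v * v for v in vec))
        return [v / norm for v in vec]

    p, t = normalize(prediction), normalize(target)
    return sum((a - b) ** 2 for a, b in zip(p, t))


target = ema_update([0.0, 1.0], [1.0, 1.0], momentum=0.9)
same_direction = byol_loss([2.0, 0.0], [5.0, 0.0])  # vectors differ only in scale
orthogonal = byol_loss([1.0, 0.0], [0.0, 1.0])
```

Gradients flow only through the online branch; the target side is treated as a constant and updated solely through the moving average.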

Parameters:
  • self – Module instance with required attributes:

      - backbone: TeacherStudentWrapper for feature extraction
      - projector: TeacherStudentWrapper for projection head
      - predictor: Online network predictor
      - byol_loss: BYOL loss function (optional, uses MSE if not provided)

  • batch – Input batch dictionary containing:

      - ‘image’: Tensor of augmented images [N*views, C, H, W]
      - ‘sample_idx’: Indices identifying views of the same image

  • stage – Training stage (‘train’, ‘val’, or ‘test’)

Returns:

Dictionary containing:

  • ‘embedding’: Feature representations from teacher backbone (EMA target)

  • ‘loss’: BYOL loss between predictions and targets (during training)

Note

Introduced in the BYOL paper [Grill et al., 2020].

Required Module Attributes:

  • backbone: TeacherStudentWrapper holding the online backbone and its momentum (target) copy

  • projector: TeacherStudentWrapper holding the online projector and its momentum (target) copy

  • predictor: Online network predictor

  • byol_loss: BYOL loss function (optional, MSE is used if not provided)

Key Features:

  • Uses momentum encoder for target network

  • No negative pairs required

  • MSE loss between predictions and targets

Example Config:

module:
  forward: stable_pretraining.forward.byol_forward
  backbone: ...
  projector: ...
  predictor:
    _target_: torch.nn.Sequential
    _args_:
      - _target_: torch.nn.Linear
        in_features: 256
        out_features: 4096
      - _target_: torch.nn.BatchNorm1d
        num_features: 4096
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Linear
        in_features: 4096
        out_features: 256

VICReg#

stable_pretraining.forward.vicreg_forward(self, batch, stage)

Forward function for VICReg (Variance-Invariance-Covariance Regularization).

VICReg learns representations using three criteria: variance (maintaining information), invariance (to augmentations), and covariance (decorrelating features).
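The three criteria can be written out explicitly. The following sketch follows the formulation in the VICReg paper and uses plain lists; `vicreg_terms` is a hypothetical helper, and the library's loss may normalize or weight the terms differently.

```python
import math


def vicreg_terms(z_a, z_b, eps=1e-4):
    """Return (invariance, variance, covariance) terms for two batches of embeddings."""
    n, d = len(z_a), len(z_a[0])

    # Invariance: mean squared error between matching rows of the two views.
    invariance = sum(
        (a - b) ** 2 for row_a, row_b in zip(z_a, z_b) for a, b in zip(row_a, row_b)
    ) / (n * d)

    def covariance_matrix(z):
        means = [sum(row[k] for row in z) / n for k in range(d)]
        centered = [[row[k] - means[k] for k in range(d)] for row in z]
        return [
            [sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
            for i in range(d)
        ]

    variance, covariance = 0.0, 0.0
    for z in (z_a, z_b):
        cov = covariance_matrix(z)
        # Variance: hinge that penalizes any dimension whose std falls below 1.
        variance += sum(max(0.0, 1.0 - math.sqrt(cov[k][k] + eps)) for k in range(d)) / d
        # Covariance: squared off-diagonal entries, pushing dimensions to decorrelate.
        covariance += sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j) / d
    return invariance, variance, covariance


# A collapsed batch: perfectly invariant, but the variance term flags the collapse.
collapsed = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
inv, var, cov = vicreg_terms(collapsed, collapsed)
```

The total loss is a weighted sum of the three terms; the weights correspond to what a loss configuration would expose as similarity, variance, and covariance coefficients.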

Parameters:
  • self – Module instance (automatically bound) with required attributes:

      - backbone: Feature extraction network
      - projector: Projection head for embedding transformation
      - vicreg_loss: VICReg loss with variance, invariance, and covariance terms

  • batch – Input batch dictionary containing:

      - ‘image’: Tensor of augmented images [N*views, C, H, W]
      - ‘sample_idx’: Indices identifying views of the same image

  • stage – Training stage (‘train’, ‘val’, or ‘test’)

Returns:

Dictionary containing:

  • ‘embedding’: Feature representations from backbone

  • ‘loss’: Combined VICReg loss (during training only)

Note

Introduced in the VICReg paper [Bardes et al., 2022].

Required Module Attributes:

  • backbone: Feature extraction network

  • projector: Projection head

  • vicreg_loss: VICReg loss (variance + invariance + covariance)

Key Features:

  • Variance regularization to maintain information

  • Invariance to augmentations

  • Covariance regularization to decorrelate features

Example Config:

module:
  forward: stable_pretraining.forward.vicreg_forward
  backbone: ...
  projector: ...
  vicreg_loss:
    _target_: stable_pretraining.losses.VICRegLoss
    sim_weight: 25.0
    var_weight: 25.0
    cov_weight: 1.0

Barlow Twins#

stable_pretraining.forward.barlow_twins_forward(self, batch, stage)

Forward function for Barlow Twins.

Barlow Twins learns representations by making the cross-correlation matrix between embeddings of augmented views as close to the identity matrix as possible, reducing redundancy while maintaining invariance.
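A small, framework-free sketch of this objective is shown below. It is an illustration under stated assumptions rather than the library's loss: `barlow_twins_loss` is a hypothetical function, inputs are lists of rows, and no feature column may be constant (its standard deviation would be zero).

```python
import math


def barlow_twins_loss(z_a, z_b, lambda_=0.005):
    """Barlow Twins objective for two batches of embeddings (lists of rows)."""
    n, d = len(z_a), len(z_a[0])

    def standardize(z):
        # Normalize every feature dimension to zero mean and unit std over the batch.
        columns = []
        for k in range(d):
            col = [row[k] for row in z]
            mean = sum(col) / n
            std = math.sqrt(sum((v - mean) ** 2 for v in col) / n)
            columns.append([(v - mean) / std for v in col])
        return columns  # shape [d][n]

    cols_a, cols_b = standardize(z_a), standardize(z_b)
    loss = 0.0
    for i in range(d):
        for j in range(d):
            # Entry (i, j) of the cross-correlation matrix between the two views.
            c_ij = sum(a * b for a, b in zip(cols_a[i], cols_b[j])) / n
            if i == j:
                loss += (1.0 - c_ij) ** 2  # invariance: diagonal pulled toward 1
            else:
                loss += lambda_ * c_ij ** 2  # redundancy reduction: off-diagonal toward 0
    return loss


# Decorrelated features and identical views: cross-correlation is the identity.
decorrelated = [[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]]
ideal = barlow_twins_loss(decorrelated, decorrelated)

# Both features carry the same signal, so the off-diagonal entries are penalized.
duplicated = [[1.0, 1.0], [-1.0, -1.0], [2.0, 2.0], [-2.0, -2.0]]
redundant = barlow_twins_loss(duplicated, duplicated)
```

The `lambda_` argument trades off the invariance term against the redundancy-reduction term.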

Parameters:
  • self – Module instance (automatically bound) with required attributes:

      - backbone: Feature extraction network
      - projector: Projection head (typically with BN and high dimension)
      - barlow_loss: Barlow Twins loss function

  • batch – Input batch dictionary containing:

      - ‘image’: Tensor of augmented images [N*views, C, H, W]
      - ‘sample_idx’: Indices identifying views of the same image

  • stage – Training stage (‘train’, ‘val’, or ‘test’)

Returns:

Dictionary containing:

  • ‘embedding’: Feature representations from backbone

  • ‘loss’: Barlow Twins loss (during training only)

Note

Introduced in the Barlow Twins paper [Zbontar et al., 2021].

Required Module Attributes:

  • backbone: Feature extraction network

  • projector: Projection head

  • barlow_loss: Barlow Twins loss function

Key Features:

  • Reduces redundancy between embedding components

  • Makes cross-correlation matrix close to identity

  • No negative pairs or momentum encoder needed

Example Config:

module:
  forward: stable_pretraining.forward.barlow_twins_forward
  backbone: ...
  projector: ...
  barlow_loss:
    _target_: stable_pretraining.losses.BarlowTwinsLoss
    lambda_: 0.005

Supervised#

stable_pretraining.forward.supervised_forward(self, batch, stage)

Forward function for standard supervised training.

This function implements traditional supervised learning with labels, useful for baseline comparisons and fine-tuning pre-trained models.
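The data flow of a supervised step can be sketched with ordinary Python callables. Both functions below are hypothetical stand-ins (not the library's code) that mirror the documented behavior: embeddings and logits are always produced, and a cross-entropy loss is added only when labels are present in the batch.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits)."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Subtract the max logit before exponentiating for numerical stability.
        peak = max(row)
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def toy_supervised_forward(backbone, classifier, batch):
    out = {"embedding": [backbone(x) for x in batch["image"]]}
    out["logits"] = [classifier(e) for e in out["embedding"]]
    if "label" in batch:  # the loss is only computed when labels are present
        out["loss"] = cross_entropy(out["logits"], batch["label"])
    return out


batch = {"image": [1.0, -1.0], "label": [0, 1]}
out = toy_supervised_forward(lambda x: x, lambda e: [e, -e], batch)
unlabeled = toy_supervised_forward(lambda x: x, lambda e: [e, -e], {"image": [1.0]})
```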

Parameters:
  • self – Module instance (automatically bound) with required attributes:

      - backbone: Feature extraction network
      - classifier: Classification head (e.g., Linear layer)

  • batch – Input batch dictionary containing:

      - ‘image’: Tensor of images [N, C, H, W]
      - ‘label’: Ground truth labels [N] (optional, for loss computation)

  • stage – Training stage (‘train’, ‘val’, or ‘test’)

Returns:

Dictionary containing:

  • ‘embedding’: Feature representations from backbone

  • ‘logits’: Classification predictions

  • ‘loss’: Cross-entropy loss (if labels provided)

Note

Unlike SSL methods, this function uses actual labels for training and is primarily used for evaluation or supervised baselines.

Required Module Attributes:

  • backbone: Feature extraction network

  • classifier: Classification head

Key Features:

  • Standard supervised learning

  • Cross-entropy loss for classification

  • Useful for baseline comparisons

Example Config:

module:
  forward: stable_pretraining.forward.supervised_forward
  backbone: ...
  classifier:
    _target_: torch.nn.Linear
    in_features: 2048
    out_features: 1000

Custom Forward Functions#

You can create custom forward functions for new SSL methods:

def custom_ssl_forward(self, batch, stage):
    """Custom SSL method forward function.

    Args:
        self: Module instance with access to all attributes
        batch: Dict containing 'image' and other data
        stage: One of 'train', 'val', or 'test'

    Returns:
        Dict with at least 'loss' key during training
    """
    out = {}

    # Extract features
    out["embedding"] = self.backbone(batch["image"])

    if self.training:
        # Your custom SSL logic here
        proj = self.projector(out["embedding"])

        # Compute custom loss
        out["loss"] = self.custom_loss(proj)

    return out

Requirements for Custom Functions:

  1. Signature: Must accept (self, batch, stage)

  2. Return: Dictionary with "loss" key during training (this is the only hardcoded requirement)

  3. Training Mode: Check self.training or stage == "train"

  4. Outputs: You can return any other keys you want (embeddings, projections, logits, etc.) with any names you choose
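When developing a new forward function, a quick sanity check of the return value can catch contract violations before a full training run. The helper `check_forward_output` below is a hypothetical utility, not part of the library; it encodes only the requirements listed above.

```python
def check_forward_output(out, stage):
    """Validate a forward function's return value; return its keys if it passes."""
    if not isinstance(out, dict):
        raise TypeError(f"forward must return a dict, got {type(out).__name__}")
    if stage == "train" and "loss" not in out:
        raise KeyError("forward must return a 'loss' entry during training")
    return sorted(out)


def toy_forward(self, batch, stage):
    # Any key names are allowed for the extra outputs; only "loss" is fixed.
    out = {"features": [2 * x for x in batch["image"]]}
    if stage == "train":
        out["loss"] = sum(out["features"])
    return out


train_keys = check_forward_output(toy_forward(None, {"image": [1, 2]}, "train"), "train")
val_keys = check_forward_output(toy_forward(None, {"image": [1, 2]}, "val"), "val")
```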

Important Note on Output Keys:

The keys you use in the output dictionary (like "embedding", "logits", etc.) are not hardcoded requirements, but they serve as references for callbacks. For example:

  • If you return out["embedding"], callbacks can access it via outputs["embedding"]

  • If you return out["features"], callbacks would access outputs["features"]

  • The OnlineProbe callback expects its input parameter to match one of your output keys

This allows flexible integration between your forward function and various callbacks.

Example: Connecting Forward Outputs to Callbacks

module:
  forward: my_custom_forward
  # Forward function returns: {"loss": ..., "my_features": ..., "my_projection": ...}

callbacks:
  - _target_: stable_pretraining.callbacks.OnlineProbe
    name: probe1
    input: my_features  # References the "my_features" key from forward output
    target: label

  - _target_: stable_pretraining.callbacks.OnlineKNN
    name: knn1
    input: my_projection  # References the "my_projection" key from forward output
    target: label

The callback’s input parameter must match the key name you chose in your forward function.
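The key matching can be pictured as a dictionary lookup. The class below is a hypothetical, callback-like helper written for illustration only; the real callbacks have their own interfaces, but a mismatched key fails in a comparable way.

```python
class OutputReader:
    """Hypothetical callback-like helper; not part of stable_pretraining."""

    def __init__(self, name, input):
        self.name = name
        self.input = input  # key to look up in the forward output

    def read(self, outputs):
        if self.input not in outputs:
            available = ", ".join(sorted(outputs))
            raise KeyError(
                f"{self.name}: no '{self.input}' in forward output (available: {available})"
            )
        return outputs[self.input]


outputs = {"loss": 0.3, "my_features": [0.1, 0.2], "my_projection": [0.5]}
probe_input = OutputReader(name="probe1", input="my_features").read(outputs)
```

Listing the available keys in the error message makes a typo in the config easy to spot.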

Integration with Module#

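In simplified form, the forward function becomes the forward method of the Module:

import lightning as pl


class Module(pl.LightningModule):
    def __init__(self, forward, **kwargs):
        super().__init__()
        # Store backbone, projector, losses, etc. as attributes (simplified)
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Bind the forward function to this instance
        self.forward = forward.__get__(self, self.__class__)

    def training_step(self, batch, batch_idx):
        # Forward function is called here
        state = self(batch, stage="train")
        return state["loss"]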

Because the method-specific logic lives in a plain function, a new SSL method can be added by writing a new forward function, without subclassing Module.

See Also#