stable_ssl.module
This module provides the main PyTorch Lightning module class for self-supervised learning. It is the core component that orchestrates training; you only need to implement the forward method!
Core Module
Module: PyTorch Lightning module with callback support, a user-defined forward method, and metrics.
User Implementation
The key insight of stable-ssl is simplicity: you only need to implement the forward method. Everything else (optimizers, schedulers, training loops, logging) is handled automatically.
Required Implementation:
```python
import stable_ssl as ssl


def forward(self, batch, stage):
    # Compute embeddings in every stage
    batch["embedding"] = self.backbone(batch["image"])
    if self.training:
        # Training-specific logic: project, regroup the views, compute the loss
        proj = self.projector(batch["embedding"])
        views = ssl.data.fold_views(proj, batch["sample_idx"])
        batch["loss"] = self.simclr_loss(views[0], views[1])
    return batch
```
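The forward method above relies on ssl.data.fold_views to regroup the projected batch by sample before computing the loss. Its exact behavior is not documented in this section, so the following is a hypothetical, list-based sketch that assumes fold_views gathers the rows sharing a sample_idx into one batch per view, so that views[0][i] and views[1][i] form a positive pair. The name fold_views_sketch is illustrative and not part of stable-ssl.

```python
from collections import defaultdict
from typing import Any, Hashable, List, Sequence


def fold_views_sketch(rows: Sequence[Any], sample_idx: Sequence[Hashable]) -> List[List[Any]]:
    """Hypothetical stand-in for ssl.data.fold_views (assumed semantics).

    Groups rows that share a sample index and returns one list per view,
    so views[v][i] is the v-th view of the i-th distinct sample.
    Assumes every sample index appears the same number of times.
    """
    if len(rows) != len(sample_idx):
        raise ValueError("rows and sample_idx must have the same length")
    groups = defaultdict(list)  # sample index -> its rows, in order of appearance
    order = []  # distinct sample indices, in order of first appearance
    for row, idx in zip(rows, sample_idx):
        if idx not in groups:
            order.append(idx)
        groups[idx].append(row)
    n_views = len(groups[order[0]]) if order else 0
    if any(len(groups[idx]) != n_views for idx in order):
        raise ValueError("every sample must have the same number of views")
    return [[groups[idx][v] for idx in order] for v in range(n_views)]


# Two samples (0 and 1), each seen under two augmentations
views = fold_views_sketch(["a0", "b0", "a1", "b1"], [0, 1, 0, 1])
```

Here views is [["a0", "b0"], ["a1", "b1"]]: the first list holds the first view of each sample and the second list the second view, which is the alignment a two-view contrastive loss expects.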
Module Creation:
```python
# backbone, projector and forward are defined elsewhere (see above)
module = ssl.Module(
    backbone=backbone,  # your model components
    projector=projector,  # any kwargs become attributes on self
    forward=forward,  # your forward function
    simclr_loss=ssl.losses.NTXEntLoss(temperature=0.1),
)
```
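The comments in the constructor call say that every keyword argument becomes an attribute on self and that forward is passed in as a plain function. One common way to obtain that behavior in Python is sketched below with a hypothetical MiniModule class; the class name and design are illustrative assumptions, not stable-ssl's implementation, which builds on a PyTorch Lightning module. Keyword arguments are stored with setattr, and the user function is bound with types.MethodType so that self inside forward refers to the module instance.

```python
import types
from typing import Any, Callable, Dict


class MiniModule:
    """Hypothetical toy showing the constructor pattern, not ssl.Module itself."""

    def __init__(self, forward: Callable[..., Dict[str, Any]], **components: Any) -> None:
        self.training = True  # mirrors the torch.nn.Module flag read by forward
        for name, component in components.items():
            setattr(self, name, component)  # e.g. self.backbone, self.projector
        # Bind the plain function so that `self` refers to this instance
        self.forward = types.MethodType(forward, self)

    def __call__(self, batch: Dict[str, Any], stage: str) -> Dict[str, Any]:
        return self.forward(batch, stage)


def forward(self, batch, stage):
    # Same shape as the user forward above, reduced to a single component
    batch["embedding"] = self.backbone(batch["image"])
    return batch


module = MiniModule(forward=forward, backbone=lambda x: x * 2)
out = module({"image": 3}, stage="validate")  # stage value is illustrative
```

After the call, out["embedding"] is 6 because the lambda passed as backbone is reachable as self.backbone inside the bound forward.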
What’s Handled Automatically:
✅ Optimizer Configuration: Default AdamW optimizer with a CosineAnnealingLR scheduler
✅ Training Loop: Automatic gradient accumulation, clipping, and stepping
✅ Stage Management: Training/validation/test/predict stages
✅ Metrics: Automatic metric logging and computation
✅ Callbacks: Integration with all stable-ssl callbacks
✅ Logging: Rich logging and monitoring
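The optimizer entry in the list above names CosineAnnealingLR as the default scheduler. PyTorch documents its closed form as eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2 for 0 <= t <= T_max. The sketch below only evaluates that formula; the base_lr, t_max and eta_min values used are illustrative and are not stable-ssl's defaults.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing, as documented for CosineAnnealingLR.

    Decays from base_lr at step 0 to eta_min at step t_max along half a cosine.
    """
    if t_max <= 0:
        raise ValueError("t_max must be positive")
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Illustrative schedule: 100 steps starting from a learning rate of 1e-3
schedule = [cosine_annealing_lr(t, base_lr=1e-3, t_max=100) for t in range(101)]
```

The learning rate starts at base_lr, is halfway between base_lr and eta_min at step t_max / 2, and reaches eta_min at step t_max.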
Key Features:
Dictionary-based: Input and output are dictionaries for maximum flexibility
Stage-aware: The stage parameter tells you whether you’re in the training, validation, test, or predict stage
Loss-driven: Include a “loss” key in the returned dictionary for training; omit it for evaluation-only use
Automatic optimization: Set optim=False if you don’t need training
Flexible components: Any kwargs become module attributes accessible via self
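The dictionary-based, stage-aware and loss-driven features above combine into one contract: forward receives a dictionary and a stage, returns a dictionary, and a "loss" entry is only required when training with optimization enabled. The hypothetical run_stage dispatcher below illustrates that contract. It is not stable-ssl's training loop, and the stage string "fit" is an assumption borrowed from PyTorch Lightning's naming.

```python
from typing import Any, Callable, Dict, Optional

Batch = Dict[str, Any]


def run_stage(forward: Callable[[Batch, str], Batch], batch: Batch, stage: str,
              optim: bool = True) -> Optional[Any]:
    """Hypothetical dispatcher for the dict-in/dict-out contract.

    Returns the value a trainer would optimize, or None when nothing
    should be optimized (evaluation stages or optim=False).
    """
    out = forward(batch, stage)
    if not isinstance(out, dict):
        raise TypeError("forward must return a dictionary")
    if stage == "fit" and optim:
        if "loss" not in out:
            raise KeyError("training with optim enabled requires a 'loss' key")
        return out["loss"]  # a real trainer would backpropagate this value
    return None  # evaluation or optim=False: the loss key is not needed


def toy_forward(batch: Batch, stage: str) -> Batch:
    out = dict(batch)
    if stage == "fit":
        out["loss"] = 0.5  # stand-in for a real loss tensor
    return out
```

With this toy forward, run_stage(toy_forward, {}, "fit") returns 0.5, while an evaluation stage or optim=False returns None, matching the "omit the loss for evaluation-only" rule.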
Example Use Cases:
Self-supervised learning: Implement contrastive losses
Supervised learning: Standard classification/regression
Multi-task learning: Multiple losses and outputs
Evaluation-only: Omit the loss key for inference
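For the contrastive use case, the module creation example plugs in ssl.losses.NTXEntLoss as simclr_loss. Its internals are not shown in this section, so the following is a dependency-free sketch of the standard SimCLR NT-Xent objective, not the library's code. For 2N normalized embeddings, each anchor's positive is the other view of the same sample, and the softmax denominator runs over all other embeddings in the combined batch. The function name nt_xent_sketch is illustrative.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _normalize(v: Vector) -> List[float]:
    # L2-normalize; assumes v is not the zero vector
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def nt_xent_sketch(view_a: Sequence[Vector], view_b: Sequence[Vector],
                   temperature: float = 0.1) -> float:
    """Pure-Python NT-Xent loss; view_a[i] and view_b[i] are a positive pair."""
    if len(view_a) != len(view_b) or not view_a:
        raise ValueError("views must be non-empty and aligned")
    z = [_normalize(v) for v in list(view_a) + list(view_b)]
    n = len(view_a)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of sample i
        logits = [_dot(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        pos_logit = _dot(z[i], z[pos]) / temperature
        # log-sum-exp with max subtraction for numerical stability
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_denom - pos_logit  # -log softmax probability of the positive
    return total / (2 * n)


aligned = nt_xent_sketch([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = nt_xent_sketch([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

When the two views of each sample agree (aligned), the loss is close to zero; when each view instead matches the other sample (swapped), the loss is large, which is the signal a contrastive objective uses to pull positives together and push negatives apart.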
The Module class is designed to be the only component you need to understand to implement any SSL algorithm! 🎯