User Guide#

This guide provides instructions for launching runs with stable-SSL.

To keep this process streamlined and reproducible, we recommend defining parameters in configuration files and managing them with Hydra.

General Idea. stable-SSL provides a highly flexible framework with minimal hardcoded utilities. Modules in the pipeline can instantiate objects from various sources, including stable-SSL, PyTorch, TorchMetrics, or even custom objects provided by the user. This allows you to seamlessly integrate your own components into the pipeline while leveraging the capabilities of stable-SSL.
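The glue behind this flexibility is Hydra's `_target_` mechanism: a string names any importable class or function, and the remaining keys become its keyword arguments. The following is a simplified, self-contained sketch of that idea (not Hydra's or stable-SSL's actual code — in practice `hydra.utils.instantiate` handles this, including nested configs):

```python
from importlib import import_module


def instantiate(config: dict):
    """Simplified sketch of Hydra-style ``_target_`` instantiation.

    The ``_target_`` string names any importable object; the remaining
    keys are passed to it as keyword arguments. Hydra's real
    ``instantiate`` also handles nesting, ``_args_``, ``_partial_``, etc.
    """
    module_path, _, attr_name = config["_target_"].rpartition(".")
    target = getattr(import_module(module_path), attr_name)
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    return target(**kwargs)


# Any importable object works -- here a stdlib class stands in for a
# stable-SSL, PyTorch, TorchMetrics, or user-defined component.
counter = instantiate({"_target_": "collections.Counter", "spam": 3})
```

Because the target is resolved by import path, your own classes plug into the pipeline exactly like built-in ones: point `_target_` at them.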

trainer#

In stable-SSL, the main trainer object must inherit from the BaseTrainer class. This class serves as the primary entry point for the training loop and provides all the essential methods required to train and evaluate your model effectively.

BaseTrainer(data, module, hardware, optim, ...)

Base class for training a model.

stable_ssl.trainers provides default trainer classes for various self-supervised learning approaches.

Here is what instantiating an SSL trainer class from stable_ssl.trainers looks like in the YAML configuration file:

_target_: stable_ssl.trainers.JointEmbeddingTrainer

loss#

The loss keyword is used to define the loss function for your model.

stable_ssl.losses offers a variety of loss functions for SSL.

Here’s an example of how to define the loss section in your YAML file:

loss:
   _target_: stable_ssl.losses.NTXEntLoss
   temperature: 0.5
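To build intuition for what this loss computes, here is a minimal pure-Python sketch of the NT-Xent (normalized temperature-scaled cross-entropy) objective — an illustration of the idea, not the `stable_ssl.losses.NTXEntLoss` implementation. It assumes two views per image, where each embedding's positive is the same index in the other view and every other embedding in the batch acts as a negative:

```python
import math


def nt_xent_pair(z1, z2, temperature=0.5):
    """Illustrative NT-Xent loss over two lists of view embeddings.

    Pure-Python sketch: for each anchor, the positive is the other view
    of the same image; all remaining embeddings are negatives.
    """
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(x * x for x in b))
        return dot / (na * nb)

    embeddings = z1 + z2              # 2N embeddings, views concatenated
    n = len(z1)
    total = 0.0
    for i, anchor in enumerate(embeddings):
        pos = (i + n) % (2 * n)       # index of the other view
        denom = sum(
            math.exp(cos(anchor, other) / temperature)
            for j, other in enumerate(embeddings) if j != i
        )
        pos_term = math.exp(cos(anchor, embeddings[pos]) / temperature)
        total += -math.log(pos_term / denom)
    return total / (2 * n)
```

Lower `temperature` sharpens the softmax, penalizing hard negatives more strongly — which is why it is the main tuning knob exposed in the config above.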

data#

The data keyword specifies the settings for data loading, preprocessing, and data augmentation. Multiple datasets can be defined; the dataset named train is used for training, while the others (which can have any name) are used for evaluation.

Example:

data:
   _num_classes: 10
   _num_samples: 50000
   train: # name 'train' indicates that this dataset should be used for training
      _target_: torch.utils.data.DataLoader
      batch_size: 256
      drop_last: True
      shuffle: True
      num_workers: ${trainer.hardware.cpus_per_task}
      dataset:
         _target_: torchvision.datasets.CIFAR10
         root: ~/data
         train: True
         transform:
            _target_: stable_ssl.data.MultiViewSampler
            transforms:
               - _target_: torchvision.transforms.v2.Compose
                 transforms:
                    - _target_: torchvision.transforms.v2.RandomResizedCrop
                      size: 32
                      scale:
                         - 0.2
                         - 1.0
                    - _target_: torchvision.transforms.v2.RandomHorizontalFlip
                      p: 0.5
                    - _target_: torchvision.transforms.v2.ToImage
                    - _target_: torchvision.transforms.v2.ToDtype
                      dtype:
                         _target_: stable_ssl.utils.str_to_dtype
                         _args_: [float32]
                      scale: True
               - ${trainer.data.train.dataset.transform.transforms.0}
   test:
      _target_: torch.utils.data.DataLoader
      batch_size: 256
      num_workers: ${trainer.hardware.cpus_per_task}
      dataset:
         _target_: torchvision.datasets.CIFAR10
         train: False
         root: ~/data
         transform:
            _target_: torchvision.transforms.v2.Compose
            transforms:
               - _target_: torchvision.transforms.v2.ToImage
               - _target_: torchvision.transforms.v2.ToDtype
                 dtype:
                    _target_: stable_ssl.utils.str_to_dtype
                    _args_: [float32]
                 scale: True
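Note the last entry of the train transforms list: the OmegaConf interpolation `${trainer.data.train.dataset.transform.transforms.0}` reuses the first augmentation pipeline as the second view, so both views are sampled from the same pipeline. Conceptually, a multi-view sampler applies each transform in its list to the same input and returns one augmented view per transform. The following is a minimal sketch of that behavior (inferred from the config above, not the `stable_ssl.data.MultiViewSampler` source):

```python
class MultiViewSampler:
    """Sketch of the multi-view idea: one augmented view per transform.

    Illustration only -- behavior inferred from the YAML config, not
    copied from the stable-SSL implementation.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        # Apply every pipeline to the same sample; joint-embedding
        # methods then pull the resulting views together in latent space.
        return [transform(x) for transform in self.transforms]


# Two toy "augmentations" standing in for the torchvision pipelines above.
sampler = MultiViewSampler([lambda x: x + 1, lambda x: x * 2])
views = sampler(10)  # -> [11, 20]
```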

module#

The module keyword defines all the neural networks used in the pipeline, including the backbone, projector, and classifier heads.

stable_ssl.modules provides a variety of utility functions that can be used to load specific architectures and pre-trained models.

Example:

module:
   backbone:
      _target_: stable_ssl.modules.load_backbone
      name: resnet18
      low_resolution: True
      num_classes: null
   projector:
      _target_: torch.nn.Sequential
      _args_:
         - _target_: torch.nn.Linear
           in_features: 512
           out_features: 2048
           bias: False
         - _target_: torch.nn.BatchNorm1d
           num_features: ${trainer.module.projector._args_.0.out_features}
         - _target_: torch.nn.ReLU
         - _target_: torch.nn.Linear
           in_features: ${trainer.module.projector._args_.0.out_features}
           out_features: 128
           bias: False
         - _target_: torch.nn.BatchNorm1d
           num_features: ${trainer.module.projector._args_.3.out_features}
   projector_classifier:
      _target_: torch.nn.Linear
      in_features: 128
      out_features: ${trainer.data._num_classes}
   backbone_classifier:
      _target_: torch.nn.Linear
      in_features: 512
      out_features: ${trainer.data._num_classes}

The various components defined above can be accessed through the dictionary self.module in your trainer class. This allows the user to define the forward pass, compute losses, and specify evaluation metrics efficiently.
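As an illustration of how these components compose, a joint-embedding forward pass typically runs every view through the backbone and then the projector. The sketch below shows that data flow with `module` as a plain dictionary of callables; the actual method names and hooks in your trainer subclass are defined by stable-SSL, not by this sketch:

```python
def forward_views(module, views):
    """Sketch of a joint-embedding forward pass over self.module.

    ``module`` mirrors the dictionary built from the YAML above
    (``backbone``, ``projector``, ...); each entry is a callable network.
    Illustrative only.
    """
    embeddings = [module["backbone"](view) for view in views]      # per-view features
    projections = [module["projector"](e) for e in embeddings]     # loss is computed on these
    return projections


# Toy stand-ins for the networks configured above.
module = {"backbone": lambda x: x * 2, "projector": lambda x: x - 1}
projections = forward_views(module, [1, 2])  # -> [1, 3]
```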

optim#

The optim keyword is used to define the optimization settings for your model. It allows users to specify both the optimizer object and the scheduler.

The default parameters associated with the optim keyword are defined in the following:

config.OptimConfig(optimizer, scheduler[, ...])

Configuration for the optimization parameters.

stable_ssl.optimizers and stable_ssl.schedulers provide additional modules that are not available in PyTorch.

Example:

optim:
   epochs: 1000
   optimizer:
      _target_: stable_ssl.optimizers.LARS
      _partial_: True
      lr: 5
      weight_decay: 1e-6
   scheduler:
      _target_: stable_ssl.schedulers.LinearWarmupCosineAnnealing
      _partial_: True
      total_steps: ${eval:'${trainer.optim.epochs} * ${trainer.data._num_samples} // ${trainer.data.train.batch_size}'}
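Note the `_partial_: True` flags: Hydra then returns a `functools.partial` instead of a constructed object, since the optimizer needs the model parameters (and the scheduler needs the optimizer), which only exist once the trainer has built the modules. To illustrate the shape of the schedule named above, here is a pure-Python sketch of linear warmup followed by cosine annealing — the parameter names and default warmup length are assumptions for illustration, not the stable-SSL signature:

```python
import math


def linear_warmup_cosine_annealing(step, total_steps, peak_lr,
                                   warmup_steps=None, end_lr=0.0):
    """Sketch of a linear-warmup + cosine-annealing LR schedule.

    Parameter names and the 10% warmup default are assumptions made
    for this illustration, not the stable-SSL API.
    """
    if warmup_steps is None:
        warmup_steps = total_steps // 10   # assumed default: 10% warmup
    if step < warmup_steps:
        return peak_lr * step / warmup_steps            # linear ramp 0 -> peak
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))   # decays 1 -> 0
    return end_lr + (peak_lr - end_lr) * cosine
```

The `total_steps` interpolation in the config computes exactly the quantity this function needs: epochs times the number of batches per epoch (`_num_samples // batch_size`).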

logger#

The logger keyword is used to configure the logging settings for your run.

One important section is metrics, which lets you define the evaluation metrics to track during training. Metrics can be specified for each dataset.

The default parameters associated with logger are defined in the following:

config.LoggerConfig(level, metric, monitor, ...)

Configuration for logging and checkpointing during training or evaluation.

Example:

logger:
   base_dir: "./"
   level: 20
   checkpoint_frequency: 1
   log_every_step: 1
   metric:
      test:
         acc1:
            _target_: torchmetrics.classification.MulticlassAccuracy
            num_classes: ${trainer.data._num_classes}
            top_k: 1
         acc5:
            _target_: torchmetrics.classification.MulticlassAccuracy
            num_classes: ${trainer.data._num_classes}
            top_k: 5
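The two metrics above track top-1 and top-5 accuracy on the test dataset: a prediction counts as correct if the true class is among the k highest-scoring classes. A pure-Python sketch of that idea (an illustration of the concept, not the torchmetrics implementation):

```python
def top_k_accuracy(scores, targets, k=1):
    """Sketch of top-k accuracy over per-sample class scores.

    A sample is correct when its true class index appears among the k
    highest-scoring classes. Illustration only, not torchmetrics code.
    """
    correct = 0
    for class_scores, target in zip(scores, targets):
        ranked = sorted(range(len(class_scores)),
                        key=lambda c: class_scores[c], reverse=True)
        if target in ranked[:k]:
            correct += 1
    return correct / len(targets)


scores = [[0.1, 0.7, 0.2],   # predicted class 1
          [0.5, 0.3, 0.2]]   # predicted class 0, true class is 2
targets = [1, 2]
```

Here `top_k_accuracy(scores, targets, k=1)` gives 0.5 (only the first sample's true class tops the ranking), while `k=3` gives 1.0.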

hardware#

Use the hardware keyword to configure hardware-related settings, such as the device, the number of GPUs (world_size), and the number of CPUs per task.

The default parameters associated with hardware are defined in the following:

config.HardwareConfig([seed, float16, ...])

Configuration for the hardware parameters.

Example:

hardware:
   seed: 0
   float16: true
   device: "cuda:0"
   world_size: 1