BaseTrainer#

class stable_ssl.BaseTrainer(data, module, hardware, optim, logger, loss=None, **kwargs)[source]#

Bases: Module

Base class for training a model.

This class provides a general boilerplate for common operations that occur during the training lifecycle. These operations include training, evaluation, checkpointing, and training restart.

The class is highly configurable, enabling customization of its internal workflows to suit diverse project requirements and use cases.

This class is intended to be subclassed for specific training methods (see the examples for more details). Each subclass must implement the following methods: forward, predict (used for supervised evaluation), and compute_loss (used for training).
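The subclassing contract above can be sketched with a stand-in abstract base class. The stub below does not import stable_ssl; the class names and return values are illustrative assumptions, and only the three required method names mirror the real trainer.

```python
from abc import ABC, abstractmethod


class BaseTrainerStub(ABC):
    """Stand-in for stable_ssl.BaseTrainer (not imported here)."""

    @abstractmethod
    def forward(self):
        """Forward pass of the model."""

    @abstractmethod
    def predict(self):
        """Predictions used for supervised evaluation."""

    @abstractmethod
    def compute_loss(self):
        """Global loss to be minimized during training."""


class ToyTrainer(BaseTrainerStub):
    """Minimal subclass implementing the three required methods."""

    def forward(self):
        return "representation"

    def predict(self):
        return "logits"

    def compute_loss(self):
        # May also return a list or dict of losses (see compute_loss below).
        return {"ssl": 0.5}
```

Instantiating the stub directly fails, as with any abstract base class, until all three methods are implemented.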

Execution flow when calling launch:
  • self.before_fit (nothing by default)

  • self._fit (executes all the training/intermittent evaluation by default)
    • for self.optim["epochs"] epochs:
      • self.before_fit_epoch (setup in train mode)

      • self._fit_epoch (one training epoch by default)
        • loop over mini-batches
          • self.before_fit_step (moves data to device)

          • self._fit_step (optimization step)

          • self.after_fit_step (performance monitoring and teacher update)

      • self.after_fit_epoch (nothing by default)

      • self._evaluate (if requested, looping over all non-train datasets)
        • self.before_eval (setup in eval mode)

        • loop over mini-batches
          • self.before_eval_step (moves data to device)

          • self._eval_step (computes eval metric)

          • self.after_eval_step (nothing by default)

        • self.after_eval (nothing by default)

      • Save an intermittent checkpoint if requested by the user config

    • Save a final checkpoint if requested by the user config

  • self.after_fit (evaluates by default)
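The hook ordering above can be condensed into a small template-method skeleton. The sketch below is self-contained and does not import stable_ssl; the class and the loop bounds are illustrative assumptions, and only the hook names mirror the real trainer.

```python
class LifecycleSketch:
    """Records the order in which the BaseTrainer-style hooks fire."""

    def __init__(self, epochs, train_batches, eval_batches):
        self.epochs = epochs
        self.train_batches = train_batches
        self.eval_batches = eval_batches
        self.trace = []  # hook names, in firing order

    def launch(self):
        self.trace.append("before_fit")
        for _ in range(self.epochs):
            self.trace.append("before_fit_epoch")
            for _ in range(self.train_batches):
                self.trace += ["before_fit_step", "_fit_step", "after_fit_step"]
            self.trace.append("after_fit_epoch")
            # evaluation over non-train datasets (when requested)
            self.trace.append("before_eval")
            for _ in range(self.eval_batches):
                self.trace += ["before_eval_step", "_eval_step", "after_eval_step"]
            self.trace.append("after_eval")
        self.trace.append("after_fit")
        return self.trace
```

Running `LifecycleSketch(epochs=2, train_batches=3, eval_batches=1).launch()` yields a trace bracketed by before_fit and after_fit, with the per-step hooks nested inside each epoch, matching the flow listed above.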

Parameters:
  • data (dict) – Names and construction of the dataloaders with their transform pipelines. The dataset named train is used for training. Any other dataset is used for validation.

  • module (dict) – Names and definition of the modules (neural networks). See stable_ssl.modules for examples of available modules.

  • hardware (dict) – Hardware parameters. See stable_ssl.config.HardwareConfig for the full list of parameters and their defaults.

  • optim (dict) – Optimization parameters. See stable_ssl.config.OptimConfig for the full list of parameters and their defaults.

  • logger (dict) – Logging and checkpointing parameters. See stable_ssl.config.LoggerConfig for the full list of parameters and their defaults.

  • loss (dict, optional) – Loss function used in the final criterion to be minimized. See stable_ssl.losses for examples. Defaults to None.

  • **kwargs – Additional arguments to be set as attributes of the class.
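A configuration sketch showing the shape of the dict arguments listed above. Every key inside the nested dicts is an illustrative assumption, not the exact schema; consult stable_ssl.config (HardwareConfig, OptimConfig, LoggerConfig) for the real parameter names and defaults.

```python
# Hypothetical trainer configuration; all nested keys are illustrative.
config = dict(
    data={
        "train": {"batch_size": 256},  # the dataset named "train" is used for training
        "val": {"batch_size": 256},    # any other name is used for validation
    },
    module={"backbone": "resnet18"},   # see stable_ssl.modules for real options
    hardware={"device": "cuda"},       # see stable_ssl.config.HardwareConfig
    optim={"epochs": 100},             # see stable_ssl.config.OptimConfig
    logger={"checkpoint_frequency": 10},  # see stable_ssl.config.LoggerConfig
    loss=None,                         # optional; see stable_ssl.losses
)
```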

after_eval()[source]#

Handle tasks after completing evaluation (currently does nothing).

after_eval_step()[source]#

Handle post-step tasks after an evaluation step (currently does nothing).

after_fit()[source]#

Evaluate the model after completing the training process.

after_fit_epoch()[source]#

Handle post-epoch tasks after training (currently does nothing).

after_fit_step()[source]#

Handle per-step monitoring and teacher update (if applicable).

before_eval()[source]#

Set the model to evaluation mode before validation/testing.

before_eval_step()[source]#

Prepare batch for evaluation step by moving it to the appropriate device.

before_fit()[source]#

Initialize training by setting the starting epoch.

before_fit_epoch()[source]#

Prepare the training state and set the epoch for distributed training.

before_fit_step()[source]#

Prepare batch for training step by moving it to the appropriate device.

checkpoint() DelayedSubmission[source]#

Create a checkpoint of the current state of the model.

This method is called asynchronously when the SLURM manager sends a preemption signal. It is invoked with the same arguments as the __call__ method. At this point, self represents the current state of the model.

Returns:

A delayed submission object representing the requeued task with the current model state.

Return type:

submitit.helpers.DelayedSubmission

clean()[source]#

Delete the working directory with logs.

abstract compute_loss()[source]#

Calculate the global loss to be minimized during training.

Compute the total loss that the model aims to minimize. Implementations can use the loss function provided at the trainer’s initialization to compute the loss on the current batch.

Note that it can return a list or dictionary of losses. The various losses are logged independently and summed to compute the final loss.

See Also:

stable_ssl.trainers for concrete examples of implementations.
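As noted above, compute_loss may return a list or dict of losses that are logged independently and summed into the final loss. A minimal sketch of that sum-and-log behavior; the helper name and the loss values are illustrative, not stable_ssl API.

```python
def total_loss(losses):
    """Sum a dict or list of per-term losses into the final scalar.

    Mirrors the documented behavior: each named loss is logged
    separately (omitted here), then all terms are summed.
    """
    values = losses.values() if isinstance(losses, dict) else losses
    return sum(values)


# A compute_loss implementation might return, e.g.:
losses = {"ssl": 0.8, "classifier": 0.2}
```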

abstract forward()[source]#

Forward pass of the model.

get_config()[source]#

Retrieve the configuration file of the trainer.

get_logs(keys=None, min_step=0, max_step=-1)[source]#

Retrieve the logs from the logger.

get_project_logs(keys=None, state=['finished'])[source]#

Retrieve the project logs from the logger.

launch()[source]#

Execute the core training and evaluation routine.

This method runs the training and evaluation process, with a customizable boilerplate structure.

The default flow:
  • Run evaluation and cleanup if no "train" dataset is found in self.data.

  • Otherwise, perform pre-training, training, and post-training tasks.

Raises:
  • BreakAllEpochs – Raised if the training is interrupted by the user.

abstract predict()[source]#

Generate model predictions for evaluation purposes.

Supervised and self-supervised models are typically evaluated using predictions over discrete labels. This method should return the classification output used for evaluation.

In SSL, this typically involves using a classifier head on top of the backbone, thus turning the SSL model into a supervised model for evaluation.

See Also:

stable_ssl.trainers for concrete examples of implementations.
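The classifier-head pattern described above can be sketched as follows. `backbone` and `head` are stand-in callables, not stable_ssl APIs; in a real BaseTrainer subclass, predict would typically read the current batch from the trainer state rather than take it as an argument.

```python
def predict(batch, backbone, head):
    """Return the classification output used for evaluation."""
    embedding = backbone(batch)  # representation from the SSL backbone
    return head(embedding)       # class scores from the classifier head
```

With toy callables standing in for the networks, e.g. `backbone = lambda b: [x * 2 for x in b]` and `head = sum`, `predict([1, 2, 3], backbone, head)` returns the head's output over the doubled inputs.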

setup()[source]#

Instantiate components and load the checkpoint (if applicable).