BaseTrainer
- class stable_ssl.BaseTrainer(data, module, hardware, optim, logger, loss=None, **kwargs)
Bases: Module
Base class for training a model.
This class provides a general boilerplate for common operations that occur during the training lifecycle. These operations include training, evaluation, checkpointing, and training restart.
The class is highly configurable, enabling customization of its internal workflows to suit diverse project requirements and use cases.
This class is intended to be subclassed for specific training methods (see examples for more details). Each subclass must implement the following methods: forward, predict (used for supervised evaluation), and compute_loss (used for training).
- Execution flow when calling launch:
- self.before_fit (nothing by default)
- self._fit (executes all the training/intermittent evaluation by default)
  - for self.optim["epochs"] epochs:
    - self.before_fit_epoch (setup in train mode)
    - self._fit_epoch (one training epoch by default)
      - loop over mini-batches
        - self.before_fit_step (moves data to device)
        - self._fit_step (optimization step)
        - self.after_fit_step (perf monitoring and teacher update)
    - self.after_fit_epoch (nothing by default)
    - self._evaluate (if asked, looping over all non-train datasets)
      - self.before_eval (setup in eval mode)
      - loop over mini-batches
        - self.before_eval_step (moves data to device)
        - self._eval_step (computes eval metric)
        - self.after_eval_step (nothing by default)
      - self.after_eval (nothing by default)
    - Save intermittent checkpoint if asked by user config
  - Save final checkpoint if asked by user config
- self.after_fit (evaluates by default)
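The hook order above can be traced with a small stand-alone sketch. ToyTrainer is a hypothetical stand-in, not the library's implementation: it only records the name of each hook in the order the list describes, and it leaves out checkpointing and the evaluation that after_fit runs by default.

```python
class ToyTrainer:
    """Hypothetical stand-in that records the order in which hooks fire."""

    def __init__(self, epochs, train_batches, eval_batches):
        self.optim = {"epochs": epochs}
        self.train_batches = train_batches
        self.eval_batches = eval_batches
        self.calls = []  # names of the hooks, in call order

    def launch(self):
        self.calls.append("before_fit")
        self._fit()
        self.calls.append("after_fit")

    def _fit(self):
        for _ in range(self.optim["epochs"]):
            self.calls.append("before_fit_epoch")
            self._fit_epoch()
            self.calls.append("after_fit_epoch")
            self._evaluate()

    def _fit_epoch(self):
        # One pass over the training mini-batches.
        for _ in self.train_batches:
            self.calls.append("before_fit_step")
            self.calls.append("_fit_step")
            self.calls.append("after_fit_step")

    def _evaluate(self):
        self.calls.append("before_eval")
        # One pass over the evaluation mini-batches.
        for _ in self.eval_batches:
            self.calls.append("before_eval_step")
            self.calls.append("_eval_step")
            self.calls.append("after_eval_step")
        self.calls.append("after_eval")


toy = ToyTrainer(epochs=1, train_batches=[0], eval_batches=[0])
toy.launch()
print(toy.calls)
```

In a real subclass, customization would go into the before_* and after_* hooks rather than into the loop itself, which is what keeps the loop reusable across training methods.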
- Parameters:
data (dict) – Names and construction of the dataloaders with their transform pipelines. The dataset named train is used for training. Any other dataset is used for validation.
module (dict) – Names and definition of the modules (neural networks). See stable_ssl.modules for examples of available modules.
hardware (dict) – Hardware parameters. See stable_ssl.config.HardwareConfig for the full list of parameters and their defaults.
optim (dict) – Optimization parameters. See stable_ssl.config.OptimConfig for the full list of parameters and their defaults.
logger (dict) – Logging and checkpointing parameters. See stable_ssl.config.LoggerConfig for the full list of parameters and their defaults.
loss (dict, optional) – Loss function used in the final criterion to be minimized. See stable_ssl.losses for examples. Defaults to None.
**kwargs – Additional arguments to be set as attributes of the class.
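Two of the conventions above lend themselves to a short illustration: the dataset named train is the training set while every other name is used for validation, and extra keyword arguments become attributes. The sketch below is a stand-alone toy under those assumptions. split_dataset_names and KwargsAsAttributes are hypothetical helpers, and the dictionary values are placeholders rather than real dataloader configurations.

```python
def split_dataset_names(data):
    """Return (train_name_or_None, validation_names) for a data dict."""
    train_name = "train" if "train" in data else None
    validation_names = [name for name in data if name != "train"]
    return train_name, validation_names


class KwargsAsAttributes:
    """Hypothetical class that stores extra keyword arguments as attributes."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


# Placeholder values; a real config would describe dataloaders and transforms.
data = {"train": "<train loader>", "val": "<val loader>", "test": "<test loader>"}
print(split_dataset_names(data))

obj = KwargsAsAttributes(temperature=0.5)
print(obj.temperature)
```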
- after_eval_step()
Handle post-step tasks after an evaluation step (currently does nothing).
- before_eval_step()
Prepare batch for evaluation step by moving it to the appropriate device.
- checkpoint() → DelayedSubmission
Create a checkpoint of the current state of the model.
This method is called asynchronously when the SLURM manager sends a preemption signal. It is invoked with the same arguments as the __call__ method. At this point, self represents the current state of the model.
- Returns:
A delayed submission object representing the requeued task with the current model state.
- Return type:
submitit.helpers.DelayedSubmission
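The requeue pattern described here can be sketched without a cluster. In the toy below, DelayedSubmissionStandIn is a hypothetical stand-in for submitit.helpers.DelayedSubmission (used so the block runs without submitit installed), and ToyJob is an assumed callable whose state lives on self. It is an illustration of the idea, not the library's code.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class DelayedSubmissionStandIn:
    """Hypothetical stand-in: a callable plus the arguments to re-run it with."""

    function: Callable[..., Any]
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)

    def result(self):
        # Re-run the stored callable with the stored arguments.
        return self.function(*self.args, **self.kwargs)


class ToyJob:
    """Assumed callable job whose progress is kept on the instance."""

    def __init__(self):
        self.epoch = 0

    def __call__(self, total_epochs):
        # Resume from wherever self.epoch currently is.
        while self.epoch < total_epochs:
            self.epoch += 1
        return self.epoch

    def checkpoint(self, *args, **kwargs):
        # self already holds the current state, so requeue self
        # with the same arguments that __call__ received.
        return DelayedSubmissionStandIn(self, args, kwargs)


job = ToyJob()
job.epoch = 3                    # pretend preemption happened after 3 epochs
requeued = job.checkpoint(5)     # same argument as the original call
print(requeued.result())
```

Because the requeued task wraps the same object, the rerun resumes from epoch 3 instead of starting over.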
- abstract compute_loss()
Calculate the global loss to be minimized during training.
Implementations can use the loss function provided during the trainer's initialization to compute the loss for the current batch. Note that this method can also return a list or dictionary of losses. The individual losses are logged independently and summed to compute the final loss.
- See Also:
stable_ssl.trainers for concrete examples of implementations.
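The summing behavior described for compute_loss can be illustrated with plain numbers. total_loss below is a hypothetical helper written for this sketch, not a function of the library. It assumes the returned losses are values that support addition, which also holds for tensors.

```python
def total_loss(losses):
    """Reduce a single loss, a list of losses, or a dict of losses to one value."""
    if isinstance(losses, dict):
        # Named losses: each could be logged under its key, then summed.
        return sum(losses.values())
    if isinstance(losses, (list, tuple)):
        return sum(losses)
    # Already a single loss value.
    return losses


print(total_loss({"invariance": 0.5, "variance": 0.25}))
print(total_loss([1.0, 2.0]))
print(total_loss(2.0))
```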
- launch()
Execute the core training and evaluation routine.
This method runs the training and evaluation process, with a customizable boilerplate structure.
The default flow:
- Runs evaluation and cleanup if no "train" dataset is found in self.data.
- Otherwise, performs pre-training, training, and post-training tasks.
Exceptions
- BreakAllEpochs
Raised if the training is interrupted by the user.
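The branch on the presence of a "train" dataset can be written as a few lines of stand-alone Python. toy_launch is a hypothetical function that only returns the names of the phases it would run. The phase names are illustrative labels, not library methods, and the BreakAllEpochs handling is left out.

```python
def toy_launch(data):
    """Return the phases a launch-like routine would run for a given data dict."""
    if "train" not in data:
        # No training set: evaluate only, then clean up.
        return ["evaluate", "cleanup"]
    # Training set present: pre-training, training, post-training.
    return ["before_fit", "fit", "after_fit"]


print(toy_launch({"val": "<val loader>"}))
print(toy_launch({"train": "<train loader>", "val": "<val loader>"}))
```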
- abstract predict()
Generate model predictions for evaluation purposes.
Supervised and self-supervised models are typically evaluated using predictions over discrete labels. This method should return the classification output used for evaluation.
In SSL, this typically involves placing a classifier head on top of the backbone, which turns the SSL model into a supervised model for evaluation.
- See Also:
stable_ssl.trainers for concrete examples of implementations.
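As a rough illustration of how the three required methods fit together, the sketch below uses a toy abstract base in place of BaseTrainer. Everything in it is an assumption made for the example: ToyBase, ToyLinearProbe, the self.batch attribute, and the pure-Python backbone, classifier head, and loss. A real subclass would work with neural network modules and tensors.

```python
from abc import ABC, abstractmethod


class ToyBase(ABC):
    """Hypothetical base declaring the three methods a subclass must provide."""

    @abstractmethod
    def forward(self, x): ...

    @abstractmethod
    def predict(self): ...

    @abstractmethod
    def compute_loss(self): ...


class ToyLinearProbe(ToyBase):
    """Classifier head on top of a backbone, evaluated like a supervised model."""

    def __init__(self, backbone, classifier, loss):
        self.backbone = backbone
        self.classifier = classifier
        self.loss = loss
        self.batch = None  # assumed to hold the current (input, label) pair

    def forward(self, x):
        return self.backbone(x)

    def predict(self):
        x, _ = self.batch
        # Class scores from the head applied to backbone features.
        return self.classifier(self.forward(x))

    def compute_loss(self):
        _, label = self.batch
        return self.loss(self.predict(), label)


def toy_backbone(x):
    return [x, -x]  # two "features"


def toy_classifier(features):
    return list(features)  # identity head: features act as class scores


def zero_one_loss(scores, label):
    predicted = max(range(len(scores)), key=lambda i: scores[i])
    return 0.0 if predicted == label else 1.0


probe = ToyLinearProbe(toy_backbone, toy_classifier, zero_one_loss)
probe.batch = (3.0, 0)
print(probe.predict(), probe.compute_loss())
```

Leaving any of the three methods unimplemented makes the subclass impossible to instantiate, which is the practical meaning of the abstract markers above.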