SupervisedTrainer

class stable_ssl.trainers.SupervisedTrainer(data, module, hardware, optim, logger, loss=None, **kwargs)

Bases: BaseTrainer

Base class for training a supervised model.

compute_loss()

Compute the model's loss using the loss function specified in the config.
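The pattern can be illustrated with a minimal, framework-free sketch. It assumes the configured loss is a callable that takes predictions and targets. The class `ToySupervisedTrainer`, the `batch` dict with `"inputs"`/`"labels"` keys, and the `mse` function are hypothetical stand-ins, not stable_ssl's actual internals.

```python
from typing import Callable, Dict, List, Sequence


def mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error; a hypothetical stand-in for the loss supplied via the config."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class ToySupervisedTrainer:
    """Hypothetical trainer holding a loss callable and the current batch."""

    def __init__(self, loss: Callable[[Sequence[float], Sequence[float]], float]):
        self.loss = loss
        self.batch: Dict[str, List[float]] = {}

    def predict(self) -> List[float]:
        # Hypothetical model: doubles every input of the current batch.
        return [2.0 * x for x in self.batch["inputs"]]

    def compute_loss(self) -> float:
        # Apply the configured loss to the predictions and the batch labels.
        return self.loss(self.predict(), self.batch["labels"])


trainer = ToySupervisedTrainer(loss=mse)
trainer.batch = {"inputs": [1.0, 2.0], "labels": [2.0, 5.0]}
print(trainer.compute_loss())  # 0.5
```

Because the loss is injected rather than hard-coded, swapping in a different criterion only changes the config, not the trainer.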

forward(*args, **kwargs)

Forward pass. By default, it simply calls the ‘backbone’ module.
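The sketch below shows the default delegation and one way a subclass might override it. It assumes that `module` behaves like a mapping from names to callables. `ToyTrainer`, `ToyTrainerWithHead`, and the `"classifier"` entry are hypothetical and not taken from stable_ssl.

```python
from typing import Any, Callable, Dict


class ToyTrainer:
    """Hypothetical trainer whose sub-modules are stored in a name -> callable dict."""

    def __init__(self, module: Dict[str, Callable[..., Any]]):
        self.module = module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Default behaviour: pass all arguments straight to the "backbone" entry.
        return self.module["backbone"](*args, **kwargs)


class ToyTrainerWithHead(ToyTrainer):
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Override: run the backbone, then a hypothetical "classifier" head.
        features = super().forward(*args, **kwargs)
        return self.module["classifier"](features)


backbone = lambda x, scale=1: [scale * v for v in x]
trainer = ToyTrainer(module={"backbone": backbone})
print(trainer.forward([1, 2, 3], scale=10))  # [10, 20, 30]

with_head = ToyTrainerWithHead(module={"backbone": backbone, "classifier": sum})
print(with_head.forward([1, 2, 3]))  # 6
```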

predict()

Run the forward pass on the current batch.
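As a rough illustration of how `predict()` relates to `forward()`, here is a sketch that assumes the current batch is stored as a dict whose `"inputs"` entry is fed to the model. The `ToyTrainer` class, the batch layout, and the add-one "model" are hypothetical.

```python
from typing import Any, Dict, List


class ToyTrainer:
    """Hypothetical trainer illustrating predict() -> forward(current batch)."""

    def __init__(self) -> None:
        self.batch: Dict[str, Any] = {}

    def forward(self, inputs: List[float]) -> List[float]:
        # Hypothetical model: adds one to every input.
        return [x + 1.0 for x in inputs]

    def predict(self) -> List[float]:
        # Feed only the inputs of the current batch through forward(); labels are ignored.
        return self.forward(self.batch["inputs"])


trainer = ToyTrainer()
trainer.batch = {"inputs": [0.0, 1.5], "labels": [1, 2]}
print(trainer.predict())  # [1.0, 2.5]
```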