JointEmbeddingTrainer

class stable_ssl.trainers.JointEmbeddingTrainer(data, module, hardware, optim, logger, loss=None, **kwargs)

Bases: BaseTrainer

Base class for training a joint-embedding SSL model.

compute_loss()

Compute the final loss as the sum of the SSL loss and the classifier losses.
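
As a rough illustration of the summation described above, the following sketch uses plain Python floats in place of tensors. The helper name `total_loss` and the dictionary keys are hypothetical and are not part of stable_ssl; the documented method takes no arguments, so how it obtains these terms is not shown here.

```python
def total_loss(ssl_loss, classifier_losses):
    """Return the SSL loss plus every classifier loss (hypothetical helper)."""
    return ssl_loss + sum(classifier_losses.values())


# Illustrative values only; a real trainer would work with tensors.
losses = {"backbone_classifier": 0.5, "projector_classifier": 0.25}
total = total_loss(1.0, losses)
```

Summing the terms into a single scalar means one backward pass can, in principle, update the SSL objective and the classifier heads together.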

compute_loss_classifiers(representations, embeddings, labels)

Compute the classifier losses for both the backbone and the projector.
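
The sketch below shows one way two per-head classifier losses could be computed. It assumes, without confirmation from the source, that each loss is a mean cross-entropy and that logits have already been produced by classifier heads applied to `representations` and `embeddings`. The function names and the returned dictionary keys are hypothetical, and plain Python lists stand in for tensors.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one sample: log-sum-exp(logits) - logits[label]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def classifier_losses(backbone_logits, projector_logits, labels):
    """Mean cross-entropy for each of the two heads (hypothetical helper)."""
    n = len(labels)
    return {
        "backbone": sum(cross_entropy(z, y) for z, y in zip(backbone_logits, labels)) / n,
        "projector": sum(cross_entropy(z, y) for z, y in zip(projector_logits, labels)) / n,
    }


# Two samples, three classes; the logits are made-up numbers.
labels = [0, 2]
backbone_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]
projector_logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
result = classifier_losses(backbone_logits, projector_logits, labels)
```

In this toy input the projector logits are uniform, so that head's loss equals the log of the number of classes, while the backbone head, which favours the correct class, scores lower.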

forward(*args, **kwargs)

Forward pass. By default, it simply calls the ‘backbone’ module.

predict()

Call the backbone classifier on the forward pass of the current batch.
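
To make the relationship between `forward` and `predict` concrete, here is a minimal sketch of the delegation pattern the descriptions above suggest. `TinyTrainer` is a hypothetical stand-in, not the stable_ssl class; the `"backbone_classifier"` key, the `batch` attribute, and the toy callables are all assumptions made for illustration.

```python
class TinyTrainer:
    """Hypothetical stand-in, not the stable_ssl class."""

    def __init__(self, module, batch):
        self.module = module  # mapping from name to callable
        self.batch = batch    # (inputs, labels) pair

    def forward(self, *args, **kwargs):
        # Default behaviour described above: delegate to the backbone.
        return self.module["backbone"](*args, **kwargs)

    def predict(self):
        # Run the forward pass on the current inputs, then classify.
        inputs, _ = self.batch
        return self.module["backbone_classifier"](self.forward(inputs))


# Toy callables standing in for neural network modules.
module = {
    "backbone": lambda xs: [2 * x for x in xs],
    "backbone_classifier": lambda hs: [int(h > 1) for h in hs],
}
trainer = TinyTrainer(module, batch=([0.2, 0.9], [0, 1]))
predictions = trainer.predict()
```

Because `predict` goes through `forward`, a subclass that overrides `forward` would change what the classifier sees without having to touch `predict`.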