JointEmbeddingPredictiveTrainer

class stable_ssl.trainers.JointEmbeddingPredictiveTrainer(data, module, hardware, optim, logger, loss=None, **kwargs)

Bases: BaseTrainer

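Base class for training a joint-embedding predictive architecture (JEPA). A context encoder embeds the input, a predictor maps the resulting context latents to predictions of the target latents produced by a separate target encoder, and the loss measures the distance between predicted and target latents.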
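The data flow between the context encoder, target encoder, predictor and loss can be illustrated with a toy, framework-free sketch. Everything in it is hypothetical: `context_encoder`, `target_encoder`, `predictor` and `jepa_step` are stand-ins written for illustration, not `stable_ssl` APIs, and latents are plain Python lists rather than tensors.

```python
# Toy sketch of one JEPA step; all names are hypothetical, not stable_ssl APIs.
from typing import List


def context_encoder(x: List[float]) -> List[float]:
    # Stand-in for the context encoder (cf. forward): scale each feature.
    return [2.0 * v for v in x]


def target_encoder(x: List[float]) -> List[float]:
    # Stand-in for the target encoder (cf. forward_target).
    return [2.0 * v for v in x]


def predictor(z_context: List[float]) -> List[float]:
    # Stand-in for the predictor (cf. forward_predictor): maps context
    # latents to a prediction of the target latents.
    return [v + 1.0 for v in z_context]


def l1_distance(pred: List[float], target: List[float]) -> float:
    # Mean absolute difference between predicted and target latents.
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def jepa_step(context_view: List[float], target_view: List[float]) -> float:
    z_context = context_encoder(context_view)
    z_target = target_encoder(target_view)
    z_pred = predictor(z_context)
    return l1_distance(z_pred, z_target)


loss = jepa_step([0.0, 1.0, 2.0], [0.5, 1.5, 2.5])
print(loss)  # 0.0
```

Each stage is a separate function so that, as in the class above, the context encoder, target encoder and predictor can be swapped independently.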

compute_loss()

Compute the final loss as the L1 distance between the predicted and target latents.
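A minimal sketch of this loss follows. It assumes latents are equal-length lists of per-token feature vectors and that absolute differences are averaged over all elements. The reduction actually used by `compute_loss` is not stated here, and `l1_loss` is a hypothetical helper, not part of `stable_ssl`.

```python
from typing import List

Latents = List[List[float]]  # one feature vector per token/patch


def l1_loss(predicted: Latents, target: Latents) -> float:
    """Mean absolute error between predicted and target latents (hypothetical helper)."""
    if len(predicted) != len(target):
        raise ValueError("predicted and target must have the same number of tokens")
    total, count = 0.0, 0
    for p_vec, t_vec in zip(predicted, target):
        if len(p_vec) != len(t_vec):
            raise ValueError("latent dimensions must match")
        for p, t in zip(p_vec, t_vec):
            total += abs(p - t)
            count += 1
    if count == 0:
        raise ValueError("latents must not be empty")
    return total / count


pred = [[1.0, 2.0], [3.0, 4.0]]
targ = [[1.5, 2.0], [2.0, 4.0]]
print(l1_loss(pred, targ))  # 0.375
```

Averaging rather than summing keeps the loss scale independent of the number of target tokens. With PyTorch tensors, `torch.nn.functional.l1_loss(pred, target)` computes the same quantity under its default `reduction='mean'`.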

forward(*args, **kwargs)

Forward pass of the context encoder.
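In I-JEPA-style training the context encoder usually receives only a visible subset of the input patches (the context), and the remaining patches serve as prediction targets. This page does not say whether this class splits its inputs that way. The sketch below only illustrates the idea: `split_patches` and `encode_context` are hypothetical helpers, and the "encoder" is a trivial per-patch summary rather than a network.

```python
import random
from typing import List, Sequence, Tuple


def split_patches(num_patches: int, context_ratio: float,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: randomly split patch indices into context and target sets."""
    rng = random.Random(seed)
    indices = list(range(num_patches))
    rng.shuffle(indices)
    n_context = max(1, int(num_patches * context_ratio))
    return sorted(indices[:n_context]), sorted(indices[n_context:])


def encode_context(patches: Sequence[List[float]],
                   context_idx: Sequence[int]) -> List[List[float]]:
    """Toy context encoder: embeds only the visible (context) patches."""
    # A real encoder would be a learned network (e.g. a ViT); here each
    # visible patch is mapped to its mean and max as a 2-D "latent".
    return [[sum(patches[i]) / len(patches[i]), max(patches[i])] for i in context_idx]


patches = [[float(i), float(i + 1)] for i in range(8)]
ctx_idx, tgt_idx = split_patches(len(patches), context_ratio=0.75)
z_context = encode_context(patches, ctx_idx)
print(len(ctx_idx), len(tgt_idx), len(z_context))  # 6 2 6
```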

forward_predictor(*args, **kwargs)

Forward pass of the predictor, which maps the context latents to predictions of the target latents.
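A predictor generally needs to know which target positions it should predict. The sketch below assumes, as in I-JEPA, that the predictor receives the context latents together with the target positions. `predict_targets` and its pooling-plus-positional-offset rule are hypothetical stand-ins for a learned network, not the predictor this class uses.

```python
from typing import List, Sequence


def predict_targets(z_context: Sequence[List[float]], target_idx: Sequence[int],
                    pos_scale: float = 0.1) -> List[List[float]]:
    """Toy predictor: returns one predicted latent per target position."""
    dim = len(z_context[0])
    # Pool the context latents into a single summary vector.
    pooled = [sum(z[d] for z in z_context) / len(z_context) for d in range(dim)]
    # Condition on each target position with a simple additive positional term.
    return [[v + pos_scale * idx for v in pooled] for idx in target_idx]


z_context = [[1.0, 0.0], [3.0, 2.0]]
print(predict_targets(z_context, target_idx=[0, 5]))  # [[2.0, 1.0], [2.5, 1.5]]
```

The output has one row per target position, so it can be compared element-wise with the target encoder's latents at those positions.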

forward_target(*args, **kwargs)

Forward pass of the target encoder.
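In many JEPA methods (I-JEPA among them), the target encoder is not updated by gradient descent. Its weights are an exponential moving average (EMA) of the context encoder's weights, and its outputs are treated as constants. This page does not say whether `forward_target` follows that scheme, so what follows is only an illustrative sketch. `ema_update` and `forward_target_toy` are hypothetical, weights are plain lists, and `momentum=0.996` is just an illustrative default.

```python
from typing import List


def ema_update(target_w: List[float], context_w: List[float],
               momentum: float = 0.996) -> List[float]:
    """Move target weights toward context weights: m * target + (1 - m) * context."""
    if not 0.0 <= momentum <= 1.0:
        raise ValueError("momentum must lie in [0, 1]")
    return [momentum * t + (1.0 - momentum) * c for t, c in zip(target_w, context_w)]


def forward_target_toy(weights: List[float], x: List[float]) -> List[float]:
    """Toy target encoder: element-wise scaling.

    In a real framework the result would be computed without gradients
    (e.g. detached) so that only the context encoder and predictor learn.
    """
    return [w * v for w, v in zip(weights, x)]


target_w = [0.0, 1.0]
context_w = [1.0, 1.0]
target_w = ema_update(target_w, context_w, momentum=0.5)
print(target_w)  # [0.5, 1.0]
print(forward_target_toy(target_w, [2.0, 3.0]))  # [1.0, 3.0]
```

A momentum close to 1 makes the target encoder change slowly, which keeps the prediction targets stable between steps.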