JointEmbeddingPredictiveTrainer
- class stable_ssl.trainers.JointEmbeddingPredictiveTrainer(data, module, hardware, optim, logger, loss=None, **kwargs)
Bases: BaseTrainer
Base class for training a joint-embedding predictive architecture.
- compute_loss()
Compute the final loss as the L1 distance between the predicted and target latents.
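The following is a minimal, dependency-free sketch of what "the L1 distance between the predicted and target latents" computes. It is not the stable_ssl implementation. The helper name `l1_latent_loss` and its `reduction` parameter are hypothetical. Two details are not stated above and are assumptions here: that the distance is averaged over all elements (`"mean"`) rather than summed, and that latents arrive as a batch of equal-length vectors.

```python
from typing import Sequence


def l1_latent_loss(
    predicted: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    reduction: str = "mean",
) -> float:
    """Hypothetical helper: L1 distance between predicted and target latents.

    Each argument is a batch of latent vectors, one inner sequence per sample.
    reduction="mean" averages |p - t| over every element (an assumption);
    reduction="sum" returns the plain L1 distance summed over the batch.
    """
    if len(predicted) != len(target):
        raise ValueError("batch sizes differ")
    total = 0.0
    count = 0
    for p_vec, t_vec in zip(predicted, target):
        if len(p_vec) != len(t_vec):
            raise ValueError("latent dimensions differ")
        # Accumulate absolute element-wise differences.
        for p, t in zip(p_vec, t_vec):
            total += abs(p - t)
            count += 1
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / count if count else 0.0
    raise ValueError(f"unknown reduction: {reduction!r}")


if __name__ == "__main__":
    predicted = [[1.0, 2.0], [0.0, -1.0]]
    target = [[0.5, 2.0], [1.0, -1.0]]
    print(l1_latent_loss(predicted, target))         # 0.375
    print(l1_latent_loss(predicted, target, "sum"))  # 1.5
```

With PyTorch tensors, the mean-reduced variant corresponds to `torch.nn.functional.l1_loss(predicted, target)`, whose default reduction is `"mean"`. In JEPA-style training the target latents are commonly treated as constants, for example by passing `target.detach()`. Whether this trainer does so is not stated here.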