SimCLR
- class stable_ssl.SimCLR(config, *args, **kwargs)
Bases: JETrainer
Implementation of the SimCLR model introduced in [CKNH20].
Reference
[CKNH20] Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A Simple Framework for Contrastive Learning of Visual Representations. In International Conference on Machine Learning (pp. 1597-1607). PMLR.
- compute_ssl_loss(h_i, h_j)
Compute the contrastive loss for SimCLR.
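The contrastive objective used by SimCLR in [CKNH20] is the normalized temperature-scaled cross-entropy (NT-Xent) loss: each embedding must identify the other augmented view of the same image among all 2N embeddings in the batch. The sketch below is a minimal, hypothetical PyTorch version of that loss, not the stable_ssl implementation of compute_ssl_loss. It assumes h_i and h_j are (batch, dim) tensors whose rows are paired by index, and it introduces a hypothetical temperature parameter whose default of 0.5 is chosen only for illustration.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(h_i: torch.Tensor, h_j: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Hypothetical NT-Xent sketch; not stable_ssl's compute_ssl_loss.

    Assumes h_i[k] and h_j[k] are two augmented views of the same sample.
    `temperature` is an illustrative, assumed hyperparameter.
    """
    batch_size = h_i.shape[0]
    # Stack both views and L2-normalize so dot products become cosine similarities.
    z = F.normalize(torch.cat([h_i, h_j], dim=0), dim=1)  # shape (2N, D)
    sim = z @ z.T / temperature  # shape (2N, 2N)
    # An embedding must not be compared with itself: drop the diagonal from the softmax.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row k's positive is its other view: index k + N in the first half, k - N in the second.
    targets = torch.cat([
        torch.arange(batch_size, 2 * batch_size),
        torch.arange(0, batch_size),
    ]).to(z.device)
    # Cross-entropy per row, averaged over all 2N anchors.
    return F.cross_entropy(sim, targets)


if __name__ == "__main__":
    torch.manual_seed(0)
    h_i = torch.randn(8, 128)
    h_j = h_i + 0.01 * torch.randn(8, 128)  # nearly identical "views"
    print(nt_xent_loss(h_i, h_j).item())
```

Framing the loss as a cross-entropy over a masked similarity matrix lets a single batched call score every positive pair against all 2(N - 1) negatives, and because both halves of the stacked batch act as anchors, the result is symmetric in h_i and h_j.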
- Parameters:
h_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
h_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed contrastive loss.
- Return type: