SimCLR

class stable_ssl.SimCLR(config, *args, **kwargs)

Bases: JETrainer

SimCLR model from [CKNH20].

Reference

[CKNH20]

Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A Simple Framework for Contrastive Learning of Visual Representations. In International Conference on Machine Learning (pp. 1597-1607). PMLR.

compute_ssl_loss(h_i, h_j)

Compute the contrastive loss for SimCLR.

Parameters:
  • h_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • h_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed contrastive loss.

Return type:

float
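The following is a minimal sketch of the NT-Xent (normalized temperature-scaled cross-entropy) loss described in [CKNH20]. It assumes that h_i and h_j are (N, d) tensors in which row k of each comes from two augmentations of the same image. The function name nt_xent_loss and the default temperature of 0.5 are illustrative assumptions. This sketch is not necessarily how stable_ssl implements compute_ssl_loss. It also returns a 0-dim tensor rather than a Python float, so the result can be backpropagated.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(h_i: torch.Tensor, h_j: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Hypothetical NT-Xent loss sketch; `temperature=0.5` is an assumed default."""
    if h_i.shape != h_j.shape:
        raise ValueError("h_i and h_j must have the same shape")
    n = h_i.shape[0]
    # Stack both views: rows 0..n-1 come from h_i, rows n..2n-1 come from h_j.
    z = F.normalize(torch.cat([h_i, h_j], dim=0), dim=1)
    # Temperature-scaled cosine similarity between all 2N x 2N pairs.
    sim = z @ z.T / temperature
    # Drop each embedding's similarity with itself from the softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row k is its other view: k + n in the first half, k - n in the second.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Cross-entropy gives -log softmax at the positive, averaged over all 2N anchors.
    return F.cross_entropy(sim, targets)
```

Masking the diagonal with -inf plays the role of the indicator that excludes k = i in the paper's loss. Delegating the final step to cross-entropy computes the per-anchor log-softmax in a numerically stable way. For a usage example, two nearly identical views (for instance, h_j = h_i + 0.01 * torch.randn_like(h_i)) should yield a lower loss than two unrelated random batches.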