NTXEntLoss
- class stable_pretraining.losses.NTXEntLoss(temperature: float = 0.5)
Bases: ContrastiveLoss
Normalized temperature-scaled cross entropy (NT-Xent) loss.
Introduced in the SimCLR paper [Chen et al., 2020]. Also used in MoCo [He et al., 2020].
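For reference, here is the standard SimCLR form of the objective. This is the textbook definition, not a transcription of this class, so the library's exact masking or reduction may differ. The N rows of each view are stacked into 2N vectors h_1, …, h_{2N}. Row a's positive p(a) is the other view of the same sample, sim is cosine similarity, and τ is the temperature:

$$
\ell_a = -\log \frac{\exp\big(\mathrm{sim}(h_a, h_{p(a)})/\tau\big)}{\sum_{b \neq a} \exp\big(\mathrm{sim}(h_a, h_b)/\tau\big)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{a=1}^{2N} \ell_a,
\qquad
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}
$$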
- Parameters:
temperature (float, optional) – The temperature scaling factor. Default is 0.5.
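The following minimal PyTorch sketch implements the formula above for illustration. The function name `nt_xent` is hypothetical. This is not the library's `forward`, and details such as how self-similarities are masked or how the loss is reduced may differ in the actual implementation.

```python
import torch
import torch.nn.functional as F


def nt_xent(z_i: torch.Tensor, z_j: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Hypothetical NT-Xent sketch (SimCLR-style); not stable_pretraining's code."""
    n = z_i.shape[0]
    # Stack both views: rows 0..n-1 are view 1, rows n..2n-1 are view 2.
    z = torch.cat([z_i, z_j], dim=0)
    # L2-normalize each row so dot products become cosine similarities.
    z = F.normalize(z, dim=1)
    # (2n, 2n) cosine-similarity matrix divided by the temperature.
    logits = z @ z.T / temperature
    # Drop each row's similarity with itself from the softmax denominator.
    logits.fill_diagonal_(float("-inf"))
    # Row a's positive is its other view: index a + n for view 1, a - n for view 2.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    # cross_entropy = mean over the 2n rows of -log softmax at the positive index.
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
a, b = torch.randn(8, 16), torch.randn(8, 16)
loss = nt_xent(a, b)
```

Using `cross_entropy` over a masked similarity matrix is a common way to express this loss: each row is treated as a (2N − 1)-way classification problem whose correct class is the other view of the same sample.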
- forward(z_i: Tensor, z_j: Tensor) → Tensor
Compute the NT-Xent loss.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed contrastive loss.
- Return type: