NegativeCosineSimilarity
- class stable_ssl.losses.NegativeCosineSimilarity(*args, **kwargs)
Bases: Module
Negative cosine similarity objective.
This objective is used, for instance, in BYOL [Grill et al., 2020] and SimSiam [Chen and He, 2021].
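As a rough illustration of the objective, the sketch below computes the batch mean of the negative cosine similarity between paired rows of z_i and z_j in plain PyTorch. It is not the library's source: the function name negative_cosine_similarity is hypothetical. It also assumes inputs of shape (N, D) and a mean reduction over the batch, which may differ from what stable_ssl.losses.NegativeCosineSimilarity actually does.

```python
import torch
import torch.nn.functional as F


def negative_cosine_similarity(z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean over the batch of -cos(z_i[k], z_j[k]).

    Assumptions (not taken from the library): inputs have shape (N, D),
    row k of z_i is paired with row k of z_j, and the reduction is a mean.
    """
    # Cosine similarity along the feature dimension: one value per pair, shape (N,).
    sim = F.cosine_similarity(z_i, z_j, dim=1)
    # Negate so that minimizing the loss maximizes agreement between the two views.
    return -sim.mean()


# Usage: two batches of 8 latent vectors of dimension 128 (random placeholders).
z_i = torch.randn(8, 128)
z_j = torch.randn(8, 128)
loss = negative_cosine_similarity(z_i, z_j)
```

The value ranges from -1 (views perfectly aligned) to 1 (views pointing in opposite directions), and rescaling either input leaves it unchanged because each vector is normalized. In BYOL and SimSiam, gradients are usually blocked through the target branch (for example by passing z_j.detach()); whether forward does this internally is not stated here.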
- forward(z_i, z_j)
Compute the negative cosine similarity loss between the latent representations of two augmented views (as used, e.g., in BYOL).
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type: