NegativeCosineSimilarity

class stable_ssl.losses.NegativeCosineSimilarity(*args, **kwargs)

Bases: Module

Negative cosine similarity objective.

This objective is used, for instance, in BYOL [Grill et al., 2020] and SimSiam [Chen and He, 2021].
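For reference, the objective can be written as follows. This assumes the per-sample cosine similarity between the n-th rows of z_i and z_j is averaged over a batch of N pairs; the averaging is an assumption, since this page does not state the reduction. Each cosine term lies in [-1, 1], so the loss reaches its minimum of -1 when every pair of embeddings points in the same direction.

```latex
\mathcal{L}(z_i, z_j) = -\frac{1}{N} \sum_{n=1}^{N}
  \frac{z_i^{(n)} \cdot z_j^{(n)}}{\lVert z_i^{(n)} \rVert_2 \, \lVert z_j^{(n)} \rVert_2}
```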

forward(z_i, z_j)

Compute the negative cosine similarity loss between z_i and z_j (as used, for example, by BYOL and SimSiam).

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
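The sketch below shows what a forward pass with these two inputs might compute. It is a hypothetical stand-alone function, not the library's implementation. It assumes z_i and z_j have shape (batch, dim), that the similarity is taken along the feature dimension (dim=1), and that the result is averaged over the batch.

```python
import torch
import torch.nn.functional as F


def negative_cosine_similarity(z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: mean over the batch of -cos(z_i[n], z_j[n])."""
    # Cosine similarity along the feature dimension gives one value per sample.
    sim = F.cosine_similarity(z_i, z_j, dim=1)
    # Negate so that minimizing the loss maximizes alignment, then average over the batch.
    return -sim.mean()


torch.manual_seed(0)
z = torch.randn(4, 8)
print(negative_cosine_similarity(z, 2 * z))  # identical directions, so close to -1
print(negative_cosine_similarity(z, -z))     # opposite directions, so close to 1
```

Because cosine similarity ignores vector norms, rescaling either input (here, 2 * z) does not change the loss.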

Returns:

The computed loss.

Return type:

torch.Tensor (scalar)
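As a usage illustration, SimSiam combines this criterion into a symmetric loss, 0.5 * D(p1, stopgrad(z2)) + 0.5 * D(p2, stopgrad(z1)), where p1 and p2 are predictor outputs and z1 and z2 are projector outputs. The sketch below assumes that setup. NegCosSim is a hypothetical stand-in that mirrors the documented forward(z_i, z_j) interface so that the example runs without stable_ssl installed, and simsiam_symmetric_loss is likewise a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NegCosSim(nn.Module):
    """Hypothetical stand-in mirroring the documented forward(z_i, z_j) interface."""

    def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
        return -F.cosine_similarity(z_i, z_j, dim=1).mean()


def simsiam_symmetric_loss(p1, p2, z1, z2, criterion: nn.Module) -> torch.Tensor:
    # p1/p2: predictor outputs; z1/z2: projector outputs for the two augmented views.
    # .detach() applies the stop-gradient to the target branch, as in SimSiam.
    return 0.5 * criterion(p1, z2.detach()) + 0.5 * criterion(p2, z1.detach())


# Toy tensors (batch of 4, feature dim 16). Here p* and z* are independent leaves,
# whereas in a real model p1 = predictor(z1), and so on.
torch.manual_seed(0)
z1 = torch.randn(4, 16, requires_grad=True)
z2 = torch.randn(4, 16, requires_grad=True)
p1 = torch.randn(4, 16, requires_grad=True)
p2 = torch.randn(4, 16, requires_grad=True)

loss = simsiam_symmetric_loss(p1, p2, z1, z2, NegCosSim())
loss.backward()
# The predictor outputs receive gradients; the detached targets do not.
print(p1.grad is not None, z1.grad is None)  # True True
```

Detaching the target is what distinguishes this usage from simply maximizing agreement between both branches. In this toy setup, z1 and z2 are used only through .detach(), so they receive no gradient at all.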