BarlowTwinsLoss

class stable_ssl.losses.BarlowTwinsLoss(lambd: float = 0.005)

Bases: Module

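Self-supervised learning (SSL) objective used in Barlow Twins [Zbontar et al., 2021]. It drives the cross-correlation matrix between the embeddings of two augmented views of the same batch toward the identity matrix.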
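For reference, the objective as written in the Barlow Twins paper is shown below. Here b indexes samples in the batch, i and j index embedding dimensions, z^A and z^B are the embeddings of the two views (playing the roles of z_i and z_j), mean-centered along the batch dimension, and λ plays the role of lambd. This is the paper's formulation; the exact normalization applied inside stable_ssl is not stated on this page and may differ.

```latex
\mathcal{L}_{\mathcal{BT}} = \sum_{i} \left(1 - \mathcal{C}_{ii}\right)^2
  + \lambda \sum_{i} \sum_{j \neq i} \mathcal{C}_{ij}^2,
\qquad
\mathcal{C}_{ij} = \frac{\sum_{b} z^{A}_{b,i}\, z^{B}_{b,j}}
  {\sqrt{\sum_{b} \left(z^{A}_{b,i}\right)^2}\, \sqrt{\sum_{b} \left(z^{B}_{b,j}\right)^2}}
```

The first sum (invariance term) pushes each diagonal entry of C toward 1; the second (redundancy-reduction term) pushes every off-diagonal entry toward 0.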

Parameters:

lambd (float, optional) – Weight of the off-diagonal (redundancy-reduction) terms of the cross-correlation matrix in the loss. Default is 5e-3.

forward(z_i, z_j)

Compute the Barlow Twins loss between the representations of two augmented views of a batch.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float
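A minimal PyTorch sketch of what such a loss computes, following the pseudocode in the Barlow Twins paper, is shown below. It assumes z_i and z_j are (N, D) tensors. The function barlow_twins_loss_sketch and its eps argument are hypothetical illustrations, not part of stable_ssl, and the library's forward may differ in details such as normalization or scaling. Unlike the documented return type, this sketch returns a 0-dim tensor.

```python
import torch


def barlow_twins_loss_sketch(
    z_i: torch.Tensor, z_j: torch.Tensor, lambd: float = 5e-3, eps: float = 1e-5
) -> torch.Tensor:
    """Hypothetical re-implementation of the Barlow Twins objective (paper formulation).

    z_i, z_j: (N, D) embeddings of two augmented views of the same batch.
    Returns a 0-dim tensor. This is NOT the stable_ssl source code.
    """
    n, d = z_i.shape

    def standardize(z: torch.Tensor) -> torch.Tensor:
        # Zero mean and unit (biased) variance per feature over the batch dimension;
        # eps is an assumed stabilizer against zero-variance features.
        centered = z - z.mean(dim=0)
        return centered / (centered.pow(2).mean(dim=0) + eps).sqrt()

    # D x D empirical cross-correlation matrix between the two views.
    c = standardize(z_i).T @ standardize(z_j) / n

    # Invariance term: push diagonal entries toward 1.
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()

    # Redundancy-reduction term: push off-diagonal entries toward 0.
    off_mask = ~torch.eye(d, dtype=torch.bool, device=c.device)
    off_diag = c[off_mask].pow(2).sum()

    return on_diag + lambd * off_diag


# Usage: two random "views" of a batch of 256 samples with 128-dim embeddings.
torch.manual_seed(0)
z_a, z_b = torch.randn(256, 128), torch.randn(256, 128)
loss = barlow_twins_loss_sketch(z_a, z_b)
print(loss.item())
```

Since the off-diagonal sum has D(D-1) terms while the diagonal sum has only D, a small weight such as the default 5e-3 keeps the redundancy-reduction term from dominating as the embedding dimension grows; treat this as an intuition for the default rather than a tuning rule.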