VICRegLoss

class stable_ssl.losses.VICRegLoss(sim_coeff: float = 25, std_coeff: float = 25, cov_coeff: float = 1, epsilon: float = 0.0001)

Bases: torch.nn.Module

Self-supervised learning (SSL) objective used in VICReg (Variance-Invariance-Covariance Regularization) [Bardes et al., 2021].
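For reference, the VICReg paper defines the objective on two batches of embeddings Z, Z' of shape n × d (batch size × embedding dimension) as shown below. Mapping λ, μ, ν and ε to sim_coeff, std_coeff, cov_coeff and epsilon is our reading of the parameter names, and the exact normalization constants used by stable_ssl may differ. The defaults 25 / 25 / 1 match the coefficient values recommended in the paper.

```latex
\begin{aligned}
\ell(Z, Z') &= \lambda\, s(Z, Z') + \mu \left[ v(Z) + v(Z') \right] + \nu \left[ c(Z) + c(Z') \right] \\
s(Z, Z') &= \frac{1}{n} \sum_{b=1}^{n} \lVert z_b - z'_b \rVert_2^2 \\
v(Z) &= \frac{1}{d} \sum_{k=1}^{d} \max\!\left(0,\; \gamma - \sqrt{\operatorname{Var}(z^{k}) + \epsilon}\right), \qquad \gamma = 1 \\
c(Z) &= \frac{1}{d} \sum_{k \neq l} \left[ C(Z) \right]_{k,l}^{2}, \qquad
C(Z) = \frac{1}{n-1} \sum_{b=1}^{n} (z_b - \bar{z})(z_b - \bar{z})^{\top}
\end{aligned}
```

Here z_b is the b-th embedding in the batch, z^k is the k-th embedding dimension taken across the batch, and \bar{z} is the batch mean.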

Parameters:
  • sim_coeff (float, optional) – Weight of the similarity (invariance) loss, the attractive term that pulls the embeddings of the two views together. Default is 25.

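  • std_coeff (float, optional) – Weight of the standard-deviation (variance) loss, which discourages collapse by penalizing embedding dimensions whose standard deviation across the batch is too small. Default is 25.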

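  • cov_coeff (float, optional) – Weight of the covariance loss, which penalizes off-diagonal entries of the embeddings' covariance matrix so that dimensions stay decorrelated. Default is 1.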

  • epsilon (float, optional) – Small value to avoid division by zero. Default is 1e-4.
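To make the roles of std_coeff's term and epsilon concrete, here is a minimal PyTorch sketch of a VICReg-style variance hinge. It follows the paper's formulation, not stable_ssl's source; the function name variance_term and the target standard deviation gamma = 1 are assumptions for illustration. On a fully collapsed batch every per-dimension variance is 0, and epsilon keeps the square root (and its gradient) finite.

```python
import torch
import torch.nn.functional as F


def variance_term(z: torch.Tensor, epsilon: float = 1e-4, gamma: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: hinge on the per-dimension std, as in the VICReg paper.

    z has shape (batch_size, dim). epsilon is added to the variance before the
    square root so the result stays finite when a dimension has zero variance.
    """
    std = torch.sqrt(z.var(dim=0) + epsilon)  # per-dimension std, shape (dim,)
    return F.relu(gamma - std).mean()         # penalize dims whose std is below gamma


# Fully collapsed batch: every sample has the same embedding.
collapsed = torch.ones(8, 4)
# Spread-out batch: per-dimension std far above gamma = 1.
spread = torch.randn(8, 4, generator=torch.Generator().manual_seed(0)) * 10

print(variance_term(collapsed))  # about 0.99: each dim has std sqrt(1e-4) = 0.01
print(variance_term(spread))     # 0 when every dimension's std exceeds 1
```

The hinge is one-sided: it only pushes spread up to the target and does not penalize dimensions that already vary enough.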

forward(z_i, z_j)

Compute the VICReg loss between the representations of two augmented views of the same batch.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:
  The computed loss.

Return type:
  float
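The following is a hypothetical, self-contained sketch of what forward(z_i, z_j) computes, written from the VICReg paper's definitions. vicreg_loss_sketch is not part of stable_ssl, both inputs are assumed to have shape (batch_size, dim), and the constant factors (averaging the two variance terms, dividing the covariance terms by the embedding dimension) are assumptions that may not match the library exactly. The sketch returns a 0-dimensional tensor so that .backward() can be called on it.

```python
import torch
import torch.nn.functional as F


def vicreg_loss_sketch(z_i, z_j, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, epsilon=1e-4):
    """Hypothetical VICReg loss; z_i and z_j have shape (batch_size, dim)."""
    n, d = z_i.shape

    # Invariance: mean squared error between the two views.
    sim_loss = F.mse_loss(z_i, z_j)

    # Center each view over the batch dimension.
    z_i = z_i - z_i.mean(dim=0)
    z_j = z_j - z_j.mean(dim=0)

    # Variance: hinge keeping each dimension's std above 1 (averaged over both views).
    std_i = torch.sqrt(z_i.var(dim=0) + epsilon)
    std_j = torch.sqrt(z_j.var(dim=0) + epsilon)
    std_loss = F.relu(1 - std_i).mean() / 2 + F.relu(1 - std_j).mean() / 2

    # Covariance: push off-diagonal covariance entries toward zero.
    cov_i = (z_i.T @ z_i) / (n - 1)
    cov_j = (z_j.T @ z_j) / (n - 1)
    off_diag = ~torch.eye(d, device=z_i.device).bool()
    cov_loss = cov_i[off_diag].pow(2).sum() / d + cov_j[off_diag].pow(2).sum() / d

    return sim_coeff * sim_loss + std_coeff * std_loss + cov_coeff * cov_loss


# Usage: two noisy views of the same random batch.
torch.manual_seed(0)
z_i = torch.randn(32, 16, requires_grad=True)
z_j = z_i.detach() + 0.1 * torch.randn(32, 16)
loss = vicreg_loss_sketch(z_i, z_j)
loss.backward()
print(loss.shape)  # torch.Size([]): a scalar tensor

# Fully collapsed views: only the variance term is non-zero.
collapsed = vicreg_loss_sketch(torch.zeros(8, 4), torch.zeros(8, 4))
print(collapsed)  # 25 * (1 - sqrt(1e-4)) = 24.75 in this sketch
```

Two identical, collapsed views pay nothing for invariance or covariance but the full variance penalty, which is the failure mode the variance term is there to prevent.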