Source code for stable_ssl.losses

"""SSL losses."""
#
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#         Randall Balestriero <randallbalestriero@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from stable_ssl.utils import all_gather, all_reduce, off_diagonal


class NTXEntLoss(torch.nn.Module):
    """Normalized temperature-scaled cross entropy loss.

    Introduced in the SimCLR paper :cite:`chen2020simple`.
    Also used in MoCo :cite:`he2020momentum`.

    Parameters
    ----------
    temperature : float, optional
        The temperature scaling factor.
        Default is 0.5.
    """

    def __init__(self, temperature: float = 0.5):
        super().__init__()
        self.temperature = temperature
    def forward(self, z_i, z_j):
        """Compute the NT-Xent loss.

        Parameters
        ----------
        z_i : torch.Tensor
            Latent representation of the first augmented view of the batch.
        z_j : torch.Tensor
            Latent representation of the second augmented view of the batch.

        Returns
        -------
        float
            The computed contrastive loss.
        """
        z_i = all_gather(z_i)
        z_j = all_gather(z_j)

        z = torch.cat([z_i, z_j], 0)
        N = z.size(0)

        features = F.normalize(z, dim=1)
        sim = torch.matmul(features, features.T) / self.temperature

        # Positive pairs sit on the off-diagonals at offsets +N//2 and -N//2.
        sim_i_j = torch.diag(sim, N // 2)
        sim_j_i = torch.diag(sim, -N // 2)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0)
        # Mask out only self-similarities; as in SimCLR, every remaining
        # entry (positives included) enters the logsumexp denominator.
        mask = torch.eye(N, dtype=bool).to(z_i.device)
        negative_samples = sim[~mask].reshape(N, -1)

        attraction = -positive_samples.mean()
        repulsion = torch.logsumexp(negative_samples, dim=1).mean()

        return attraction + repulsion
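
# Usage sketch (illustrative, not part of the library source; the 256x128
# shapes are assumptions, and `all_gather` reduces to the identity when not
# running distributed):
#
#     loss_fn = NTXEntLoss(temperature=0.5)
#     z_i = torch.randn(256, 128)  # projector outputs for view 1, (batch, dim)
#     z_j = torch.randn(256, 128)  # projector outputs for view 2, (batch, dim)
#     loss = loss_fn(z_i, z_j)     # scalar tensor, backprop as usual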
class NegativeCosineSimilarity(torch.nn.Module):
    """Negative cosine similarity objective.

    This objective is used for instance in BYOL :cite:`grill2020bootstrap`
    or SimSiam :cite:`chen2021exploring`.
    """
    def forward(self, z_i, z_j):
        """Compute the loss of the BYOL model.

        Parameters
        ----------
        z_i : torch.Tensor
            Latent representation of the first augmented view of the batch.
        z_j : torch.Tensor
            Latent representation of the second augmented view of the batch.

        Returns
        -------
        float
            The computed loss.
        """
        sim = torch.nn.CosineSimilarity(dim=1)
        return -sim(z_i, z_j).mean()
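
# Usage sketch (illustrative, not part of the library source). In BYOL- or
# SimSiam-style training one input is typically a predictor output and the
# other a stop-gradient target; the detach below is an assumption about usage,
# not something the loss itself enforces:
#
#     loss_fn = NegativeCosineSimilarity()
#     p = torch.randn(256, 128)           # predictor output
#     z = torch.randn(256, 128).detach()  # target branch, no gradient
#     loss = loss_fn(p, z)                # scalar in [-1, 1], lower is better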
class VICRegLoss(torch.nn.Module):
    """SSL objective used in VICReg :cite:`bardes2021vicreg`.

    Parameters
    ----------
    sim_coeff : float, optional
        The weight of the similarity loss (attractive term).
        Default is 25.
    std_coeff : float, optional
        The weight of the standard deviation loss.
        Default is 25.
    cov_coeff : float, optional
        The weight of the covariance loss.
        Default is 1.
    epsilon : float, optional
        Small value to avoid division by zero.
        Default is 1e-4.
    """

    def __init__(
        self,
        sim_coeff: float = 25,
        std_coeff: float = 25,
        cov_coeff: float = 1,
        epsilon: float = 1e-4,
    ):
        super().__init__()
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.epsilon = epsilon
    def forward(self, z_i, z_j):
        """Compute the loss of the VICReg model.

        Parameters
        ----------
        z_i : torch.Tensor
            Latent representation of the first augmented view of the batch.
        z_j : torch.Tensor
            Latent representation of the second augmented view of the batch.

        Returns
        -------
        float
            The computed loss.
        """
        # Invariance term: computed on the local batch, before gathering.
        repr_loss = F.mse_loss(z_i, z_j)

        z_i = all_gather(z_i)
        z_j = all_gather(z_j)

        z_i = z_i - z_i.mean(dim=0)
        z_j = z_j - z_j.mean(dim=0)

        # Variance term: hinge on the per-dimension standard deviation.
        std_i = torch.sqrt(z_i.var(dim=0) + self.epsilon)
        std_j = torch.sqrt(z_j.var(dim=0) + self.epsilon)
        std_loss = torch.mean(F.relu(1 - std_i)) / 2 + torch.mean(F.relu(1 - std_j)) / 2

        # Covariance term: penalize off-diagonal entries of each view's
        # covariance matrix (both views share the same batch size).
        cov_i = (z_i.T @ z_i) / (z_i.size(0) - 1)
        cov_j = (z_j.T @ z_j) / (z_j.size(0) - 1)
        cov_loss = off_diagonal(cov_i).pow_(2).sum().div(z_i.size(1)) + off_diagonal(
            cov_j
        ).pow_(2).sum().div(z_i.size(1))

        loss = (
            self.sim_coeff * repr_loss
            + self.std_coeff * std_loss
            + self.cov_coeff * cov_loss
        )
        return loss
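
# Usage sketch (illustrative, not part of the library source; the 256x128
# shapes are assumptions, and z_i/z_j must be row-aligned views of the same
# batch for the invariance term to make sense):
#
#     loss_fn = VICRegLoss(sim_coeff=25, std_coeff=25, cov_coeff=1)
#     z_i = torch.randn(256, 128)
#     z_j = torch.randn(256, 128)
#     loss = loss_fn(z_i, z_j)  # weighted sum of the three terms above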
class BarlowTwinsLoss(torch.nn.Module):
    """SSL objective used in Barlow Twins :cite:`zbontar2021barlow`.

    Parameters
    ----------
    lambd : float, optional
        The weight of the off-diagonal terms in the loss.
        Default is 5e-3.
    """

    def __init__(self, lambd: float = 5e-3):
        super().__init__()
        self.lambd = lambd
        self.bn = torch.nn.LazyBatchNorm1d()
    def forward(self, z_i, z_j):
        """Compute the loss of the Barlow Twins model.

        Parameters
        ----------
        z_i : torch.Tensor
            Latent representation of the first augmented view of the batch.
        z_j : torch.Tensor
            Latent representation of the second augmented view of the batch.

        Returns
        -------
        float
            The computed loss.
        """
        # Cross-correlation matrix, normalized along the batch dimension.
        c = self.bn(z_i).T @ self.bn(z_j)
        c = c / z_i.size(0)
        all_reduce(c)

        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = off_diagonal(c).pow(2).sum()
        loss = on_diag + self.lambd * off_diag
        return loss
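
# Usage sketch (illustrative, not part of the library source; shapes are
# assumptions). Note that LazyBatchNorm1d infers the feature dimension on the
# first forward pass, so the loss module has no initialized parameters until
# it has seen one batch:
#
#     loss_fn = BarlowTwinsLoss(lambd=5e-3)
#     z_i = torch.randn(256, 128)
#     z_j = torch.randn(256, 128)
#     loss = loss_fn(z_i, z_j)  # identity-matching diagonal + decorrelation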