# Source code for stable_pretraining.losses

"""SSL losses."""

import torch
import torch.nn.functional as F
from loguru import logger as logging

from .utils import all_gather, all_reduce


def mae(target, pred, mask, norm_pix_loss=False):
    """Mean squared reconstruction error restricted to masked patches.

    Args:
        target: [N, L, p*p*3] target images (patchified).
        pred: [N, L, p*p*3] predicted images.
        mask: [N, L], 0 is keep, 1 is remove. Only removed patches
            contribute to the loss.
        norm_pix_loss: if True, standardize every target patch by its own
            mean and (unbiased) variance before comparing.

    Returns:
        loss: scalar tensor. The squared error is first averaged over the
        last dimension, then averaged over patches weighted by ``mask``.
        ``mask.sum()`` is the denominator, so an all-zero mask yields NaN.
    """
    if norm_pix_loss:
        patch_mean = target.mean(dim=-1, keepdim=True)
        patch_std = (target.var(dim=-1, keepdim=True) + 1.0e-6) ** 0.5
        target = (target - patch_mean) / patch_std

    # [N, L]: mean squared error per patch.
    per_patch = (pred - target).pow(2).mean(dim=-1)

    # Average only over the patches flagged as removed.
    return (per_patch * mask).sum() / mask.sum()


def off_diagonal(x):
    """Return the off-diagonal elements of a square matrix as a 1-D tensor.

    Elements come out in row-major order with the main diagonal skipped, so an
    ``n x n`` input yields ``n * (n - 1)`` values. Depending on ``n`` and the
    input's memory layout, the result may be a copy or may share storage with
    ``x``. Do not rely on either.

    Args:
        x: 2-D tensor of shape ``(n, n)``. A non-2-D input fails when
            ``x.shape`` is unpacked.

    Returns:
        1-D tensor holding the ``n * (n - 1)`` off-diagonal entries.

    Raises:
        AssertionError: if ``x`` is not square. This is raised explicitly
            rather than through ``assert``, so the check still runs under
            ``python -O`` and the exception carries a real message.
    """
    n, m = x.shape
    if n != m:
        logging.error("Input tensor must be square.")
        raise AssertionError(
            f"Input tensor must be square, got shape {tuple(x.shape)}."
        )
    # Dropping the last element and viewing the rest as (n-1, n+1) puts every
    # diagonal entry in column 0 of that view, so slicing [:, 1:] keeps
    # exactly the off-diagonal entries.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


class NTXEntLoss(torch.nn.Module):
    """Normalized temperature-scaled cross entropy loss.

    Introduced in the SimCLR paper :cite:`chen2020simple`.
    Also used in MoCo :cite:`he2020momentum`.

    Args:
        temperature (float, optional): The temperature scaling factor.
            Default is 0.5.
    """

    def __init__(self, temperature: float = 0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Compute the NT-Xent loss.

        The two views are gathered across processes and stacked into one
        batch of ``n`` rows. Row ``k`` of the first view is paired with row
        ``k`` of the second view. The result equals the mean, over all ``n``
        anchors, of ``-s_pos + logsumexp_{k != anchor} s_k``, where ``s`` are
        temperature-scaled cosine similarities. This is the usual softmax
        cross-entropy with the positive as the target class.

        Assumes ``z_i`` and ``z_j`` have the same number of rows, so that the
        ``n // 2`` diagonal offset lines up the positive pairs.

        Args:
            z_i (torch.Tensor): Latent representation of the first augmented
                view of the batch.
            z_j (torch.Tensor): Latent representation of the second augmented
                view of the batch.

        Returns:
            torch.Tensor: Scalar contrastive loss.
        """
        view_a = torch.cat(all_gather(z_i), 0)
        view_b = torch.cat(all_gather(z_j), 0)
        stacked = torch.cat([view_a, view_b], 0)
        n = stacked.size(0)

        unit = F.normalize(stacked, dim=1)
        logits = (unit @ unit.T) / self.temperature

        # The positive for row k sits n//2 columns away, in either direction.
        positives = torch.cat(
            (torch.diag(logits, n // 2), torch.diag(logits, -n // 2)), dim=0
        )

        # Drop only self-similarity. The positive stays among the "negatives",
        # which makes the logsumexp the softmax denominator.
        self_mask = torch.eye(n, dtype=torch.bool, device=view_a.device)
        others = logits[~self_mask].reshape(n, -1)

        attraction = -positives.mean()
        repulsion = torch.logsumexp(others, dim=1).mean()
        return attraction + repulsion
class NegativeCosineSimilarity(torch.nn.Module):
    """Negative cosine similarity objective.

    This objective is used for instance in BYOL :cite:`grill2020bootstrap`
    or SimSiam :cite:`chen2021exploring`.
    """

    def forward(self, z_i, z_j):
        """Return the negated mean cosine similarity between the two views.

        Cosine similarity is taken along dim 1 (eps 1e-8) and then averaged
        over every remaining entry.

        Args:
            z_i (torch.Tensor): Latent representation of the first augmented
                view of the batch.
            z_j (torch.Tensor): Latent representation of the second augmented
                view of the batch.

        Returns:
            torch.Tensor: Scalar loss in ``[-1, 1]``. Lower means the views
            are more aligned.
        """
        cosine = F.cosine_similarity(z_i, z_j, dim=1)
        return cosine.mean().neg()
class BYOLLoss(torch.nn.Module):
    """Normalized MSE objective used in BYOL :cite:`grill2020bootstrap`.

    Both inputs are L2-normalized along the last dimension. The squared
    Euclidean distance between unit vectors is ``2 - 2 * cos``, and that
    quantity is averaged over all leading dimensions.
    """

    def forward(
        self, online_pred: torch.Tensor, target_proj: torch.Tensor
    ) -> torch.Tensor:
        """Compute BYOL loss.

        Args:
            online_pred: Predictions from the online network predictor.
            target_proj: Projections from the target network. This method
                does not call ``detach()``, so the caller must stop gradients
                if that is intended.

        Returns:
            torch.Tensor: Scalar loss value in ``[0, 4]``.
        """
        p = F.normalize(online_pred, dim=-1, p=2)
        z = F.normalize(target_proj, dim=-1, p=2)
        cosine = torch.sum(p * z, dim=-1)
        return torch.mean(2 - 2 * cosine)
class VICRegLoss(torch.nn.Module):
    """SSL objective used in VICReg :cite:`bardes2021vicreg`.

    Args:
        sim_coeff (float, optional): The weight of the similarity loss
            (attractive term). Default is 25.
        std_coeff (float, optional): The weight of the standard deviation
            loss. Default is 25.
        cov_coeff (float, optional): The weight of the covariance loss.
            Default is 1.
        epsilon (float, optional): Small value to avoid division by zero.
            Default is 1e-4.
    """

    def __init__(
        self,
        sim_coeff: float = 25,
        std_coeff: float = 25,
        cov_coeff: float = 1,
        epsilon: float = 1e-4,
    ):
        super().__init__()
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.epsilon = epsilon

    def forward(self, z_i, z_j):
        """Compute the loss of the VICReg model.

        Args:
            z_i (torch.Tensor): Latent representation of the first augmented
                view of the batch.
            z_j (torch.Tensor): Latent representation of the second augmented
                view of the batch.

        Returns:
            torch.Tensor: Scalar weighted sum of the invariance, variance
            and covariance terms.
        """
        # Invariance: MSE between the local (not yet gathered) views.
        invariance = F.mse_loss(z_i, z_j)

        x = torch.cat(all_gather(z_i), 0)
        y = torch.cat(all_gather(z_j), 0)
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # Variance: hinge pushing each feature's std toward at least 1,
        # averaged over both views.
        std_x = torch.sqrt(x.var(dim=0) + self.epsilon)
        std_y = torch.sqrt(y.var(dim=0) + self.epsilon)
        variance = F.relu(1 - std_x).mean() / 2 + F.relu(1 - std_y).mean() / 2

        # Covariance: squared off-diagonal entries, scaled by feature count.
        # Both views use the first view's sizes, matching the original
        # formulation; identical whenever the views have the same shape.
        denom = x.size(0) - 1
        num_features = x.size(1)
        cov_x = (x.T @ x) / denom
        cov_y = (y.T @ y) / denom
        covariance = off_diagonal(cov_x).pow(2).sum().div(
            num_features
        ) + off_diagonal(cov_y).pow(2).sum().div(num_features)

        return (
            self.sim_coeff * invariance
            + self.std_coeff * variance
            + self.cov_coeff * covariance
        )
class BarlowTwinsLoss(torch.nn.Module):
    """SSL objective used in Barlow Twins :cite:`zbontar2021barlow`.

    Args:
        lambd (float, optional): The weight of the off-diagonal terms in the
            loss. Default is 5e-3.
    """

    def __init__(self, lambd: float = 5e-3):
        super().__init__()
        self.lambd = lambd
        # NOTE(review): LazyBatchNorm1d defaults to affine=True (learnable
        # scale/shift). The reference Barlow Twins code appears to use a
        # non-affine BN; confirm the affine parameters are intended.
        self.bn = torch.nn.LazyBatchNorm1d()

    def forward(self, z_i, z_j):
        """Compute the loss of the Barlow Twins model.

        Args:
            z_i (torch.Tensor): Latent representation of the first augmented
                view of the batch.
            z_j (torch.Tensor): Latent representation of the second augmented
                view of the batch.

        Returns:
            torch.Tensor: Scalar loss. It is the sum of squared deviations of
            the cross-correlation diagonal from 1, plus ``lambd`` times the
            squared off-diagonal entries.
        """
        local_batch = z_i.size(0)
        # The same BN module normalizes each view along the batch dimension
        # (z_i first, then z_j).
        normed_i = self.bn(z_i)
        normed_j = self.bn(z_j)
        cross_corr = (normed_i.T @ normed_j) / local_batch

        # The return value is discarded, so this relies on an in-place
        # reduction. NOTE(review): the matrix was divided by the *local*
        # batch size. If .utils.all_reduce sums (rather than averages)
        # across processes, the entries scale with world size; confirm the
        # reduce op.
        all_reduce(cross_corr)

        on_diag = (torch.diagonal(cross_corr) - 1).pow(2).sum()
        off_diag = off_diagonal(cross_corr).pow(2).sum()
        return on_diag + self.lambd * off_diag