Source code for stable_pretraining.losses.utils

"""Utilities for SSL losses.

This module provides helper functions and utilities used by various SSL losses,
such as the Sinkhorn-Knopp optimal transport algorithm used in DINO and iBOT.
"""

import torch
import torch.distributed as dist
from loguru import logger as logging


def off_diagonal(x):
    """Return a flattened view of the off-diagonal elements of a square matrix."""
    n, m = x.shape
    if n != m:
        msg = f"Input tensor must be square, got shape {tuple(x.shape)}."
        logging.error(msg)
        raise ValueError(msg)
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
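

# Illustrative sketch (not part of the original module): the flatten/view
# trick above keeps exactly the n * (n - 1) off-diagonal entries, in row
# order, which is how Barlow Twins-style losses isolate the cross terms.
# The helper name below is hypothetical, added only for demonstration.
def _off_diagonal_example():
    x = torch.arange(9.0).view(3, 3)
    # The diagonal entries 0, 4, 8 are dropped; everything else survives.
    assert off_diagonal(x).tolist() == [1.0, 2.0, 3.0, 5.0, 6.0, 7.0]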


class NegativeCosineSimilarity(torch.nn.Module):
    """Negative cosine similarity objective.

    This objective is used for instance in BYOL :cite:`grill2020bootstrap`
    or SimSiam :cite:`chen2021exploring`.
    """
    def forward(self, z_i, z_j):
        """Compute the negative cosine similarity between two views.

        Args:
            z_i (torch.Tensor): Latent representation of the first augmented
                view of the batch.
            z_j (torch.Tensor): Latent representation of the second augmented
                view of the batch.

        Returns:
            torch.Tensor: The computed loss.
        """
        sim = torch.nn.CosineSimilarity(dim=1)
        return -sim(z_i, z_j).mean()
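

# Illustrative usage sketch (not part of the original module): a BYOL-style
# step pairs a predictor output with a stop-gradient target. The tensors and
# the helper name below are hypothetical stand-ins; any [batch, dim]
# embeddings work.
def _negative_cosine_example():
    criterion = NegativeCosineSimilarity()
    online = torch.randn(32, 128, requires_grad=True)  # online predictor output
    target = torch.randn(32, 128)  # EMA teacher output, gradient-free
    loss = criterion(online, target.detach())  # perfectly aligned views give -1
    loss.backward()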


@torch.no_grad()
def sinkhorn_knopp(
    teacher_output: torch.Tensor,
    teacher_temp: float,
    num_samples: int | torch.Tensor,
    n_iterations: int = 3,
) -> torch.Tensor:
    """Sinkhorn-Knopp algorithm for optimal transport normalization.

    This is an alternative to the simple centering used in DINO/iBOT losses.
    It performs optimal transport to assign samples to prototypes, ensuring
    a more uniform distribution across prototypes.

    Reference: DINOv3 implementation
    https://github.com/facebookresearch/dinov3

    Args:
        teacher_output: Teacher predictions [batch, prototypes] or
            [n_samples, prototypes]
        teacher_temp: Temperature for softmax
        num_samples: Number of samples to assign. Can be:
            - int: Fixed batch size (e.g., batch_size * world_size for DINO)
            - torch.Tensor: Variable count (e.g., n_masked_patches for iBOT)
        n_iterations: Number of Sinkhorn iterations (default: 3)

    Returns:
        Normalized probabilities [batch, prototypes] summing to 1 over
        prototypes

    Examples:
        # DINO CLS token loss (fixed batch size)
        Q = sinkhorn_knopp(teacher_cls_output, teacher_temp=0.04,
                           num_samples=batch_size * world_size)

        # iBOT patch loss (variable number of masked patches)
        Q = sinkhorn_knopp(teacher_patch_output, teacher_temp=0.04,
                           num_samples=n_masked_patches_tensor)
    """
    teacher_output = teacher_output.float()
    # Q is K-by-B for consistency with paper notations
    Q = torch.exp(teacher_output / teacher_temp).t()
    K = Q.shape[0]  # number of prototypes

    # Handle num_samples as tensor or int
    if isinstance(num_samples, torch.Tensor):
        num_samples = num_samples.clone()
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_samples)

    # Make the matrix sum to 1
    sum_Q = torch.sum(Q)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(sum_Q)
    Q /= sum_Q

    # Sinkhorn iterations
    for _ in range(n_iterations):
        # Normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        Q /= K

        # Normalize each column: total weight per sample must be 1/num_samples
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= num_samples

    # The columns must sum to 1 so that Q is an assignment
    Q *= num_samples
    return Q.t()
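

# Illustrative sketch (not part of the original module): a single-process call
# showing the assignment property of the result. The shapes, temperature, and
# helper name are arbitrary stand-ins, not values from any real config.
def _sinkhorn_knopp_example():
    batch, prototypes = 16, 8
    logits = torch.randn(batch, prototypes)
    Q = sinkhorn_knopp(logits, teacher_temp=0.1, num_samples=batch)
    # Each sample's assignment sums to 1 over prototypes.
    assert torch.allclose(Q.sum(dim=1), torch.ones(batch))
    # After a few iterations, mass is also spread roughly uniformly across
    # prototypes: Q.sum(dim=0) approaches batch / prototypes.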