Source code for stable_pretraining.losses.dino

"""DINO self-distillation losses.

This module contains losses for DINO-style self-distillation including:
- DINOv1Loss: CLS token distillation
- iBOTPatchLoss: Masked patch prediction
- DINOv2Loss: Combined CLS-token and masked-patch loss

Reference: DINOv2/v3 papers and https://github.com/facebookresearch/dinov3
"""

import torch
import torch.nn.functional as F

from .utils import sinkhorn_knopp


def cross_entropy_loss(t, s, temp):
    """Cross-entropy loss function for iBOT.

    Computes per-sample cross-entropy: -Σ t[i] * log_softmax(s[i]/temp)

    Args:
        t: Teacher predictions (probabilities) [*, D]
        s: Student predictions (logits) [*, D]
        temp: Temperature for student softmax

    Returns:
        Per-sample cross-entropy loss [*] (positive, lower is better)
    """
    return -torch.sum(t.float() * F.log_softmax(s.float() / temp, dim=-1), dim=-1)
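

# Illustrative sketch (not part of the library): a quick check of
# cross_entropy_loss on random tensors. The shapes, the temperature values,
# and the helper name `_example_cross_entropy_loss` are assumptions chosen
# purely for demonstration.
def _example_cross_entropy_loss():
    t = F.softmax(torch.randn(4, 128) / 0.04, dim=-1)  # teacher probabilities, rows sum to 1
    s = torch.randn(4, 128)  # student logits
    loss = cross_entropy_loss(t, s, temp=0.1)  # per-sample loss, shape [4]
    assert loss.shape == (4,) and (loss >= 0).all()
    return loss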


class DINOv1Loss(torch.nn.Module):
    """DINOv1 loss for self-distillation with cross-entropy :cite:`caron2021emerging`.

    This loss computes cross-entropy between teacher and student logits after
    applying temperature scaling and normalization. The teacher uses either
    classical centering or Sinkhorn-Knopp normalization to prevent mode collapse.

    Usage:
        ```python
        dino_loss = DINOv1Loss()

        # Get logits from the prototype layer
        student_logits = prototype_layer(student_embeddings)  # [n_views, B, out_dim]
        teacher_logits = prototype_layer(teacher_embeddings)  # [n_views, B, out_dim]

        # Approach 1: Classical centering (recommended, faster)
        teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
        loss = dino_loss(student_logits, teacher_probs)
        dino_loss.update_center(teacher_logits)  # Queue async center update

        # Approach 2: Sinkhorn-Knopp (more principled, slower, no centering needed)
        n_views, batch_size, _ = teacher_logits.shape
        num_samples = n_views * batch_size  # Total samples across views
        teacher_probs = dino_loss.sinkhorn_knopp_teacher(
            teacher_logits, teacher_temp=0.04, num_samples=num_samples
        )
        loss = dino_loss(student_logits, teacher_probs)
        # No update_center() needed for Sinkhorn-Knopp!
        ```

    Args:
        temperature_student (float): Temperature for student softmax. Default is 0.1.
        center_momentum (float): EMA momentum for the center update. Default is 0.9.
    """

    def __init__(
        self,
        temperature_student: float = 0.1,
        center_momentum: float = 0.9,
    ):
        super().__init__()
        self.temperature_student = temperature_student
        self.center_momentum = center_momentum
        self.center = None
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_output = None
        self.async_batch_center = None

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_probs: torch.Tensor,
    ) -> torch.Tensor:
        """Compute DINO cross-entropy loss.

        This is a pure loss computation with no side effects (no centering, no
        updates). Teacher probabilities should be pre-processed with
        softmax_center_teacher() or sinkhorn_knopp_teacher(). Center updates
        should be done separately with update_center().

        Args:
            student_logits: Student logits [n_views, batch_size, out_dim]
            teacher_probs: Teacher probabilities (already normalized)
                [n_views, batch_size, out_dim]

        Returns:
            Scalar DINO loss value (cross-entropy averaged over view pairs,
            excluding the diagonal)

        Shape:
            - student_logits: (S, B, K) where S = student views, B = batch size, K = out_dim
            - teacher_probs: (T, B, K) where T = teacher views
            - output: scalar
        """
        # Apply temperature-scaled log-softmax to the student
        student_log_probs = F.log_softmax(
            student_logits.float() / self.temperature_student, dim=-1
        )

        # Compute cross-entropy matrix: [S, T]
        # Sum over batch and features, keep the view dimensions
        loss_matrix = -torch.einsum("sbk,tbk->st", student_log_probs, teacher_probs)

        # Zero out the diagonal (same-view comparisons)
        loss_matrix.fill_diagonal_(0)

        # Average over valid pairs and batch size
        n_terms = loss_matrix.numel() - loss_matrix.diagonal().numel()
        batch_size = student_logits.shape[1]
        return loss_matrix.sum() / (n_terms * batch_size)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(
        self, teacher_logits, teacher_temp, num_samples=None, n_iterations=3
    ):
        """Apply Sinkhorn-Knopp optimal transport normalization to teacher logits.

        **FOR SINKHORN-KNOPP APPROACH ONLY. DOES NOT USE CENTER.**

        This method applies Sinkhorn-Knopp to enforce an exactly uniform
        distribution across prototypes without using centering. More principled
        than centering but more expensive. Used in SwAV and DINOv3 for better
        theoretical guarantees.

        Note: When using Sinkhorn-Knopp, you do NOT need to call update_center()
        or apply_center_update(), since centering is not used.

        Args:
            teacher_logits: Teacher logits [*, out_dim]. Can be any shape as long
                as the last dim is out_dim. Common shapes: [batch, out_dim] or
                [n_views, batch, out_dim]
            teacher_temp: Temperature for softmax
            num_samples: Total number of samples across all GPUs (int or tensor).
                If None, inferred from the shape assuming [batch, out_dim] format.
                For multi-view [n_views, batch, out_dim], pass n_views * batch
                explicitly.
            n_iterations: Number of Sinkhorn iterations (default: 3)

        Returns:
            Teacher probabilities [same shape as input] with uniform prototype
            distribution
        """
        # Infer num_samples if not provided
        if num_samples is None:
            # Assume shape is [batch, out_dim]
            batch_size = teacher_logits.shape[0]
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                world_size = torch.distributed.get_world_size()
                num_samples = batch_size * world_size
            else:
                num_samples = batch_size

        # Flatten all dims except the last (out_dim) for Sinkhorn-Knopp
        original_shape = teacher_logits.shape
        teacher_logits_flat = teacher_logits.view(-1, original_shape[-1])

        result = sinkhorn_knopp(
            teacher_output=teacher_logits_flat,
            teacher_temp=teacher_temp,
            num_samples=num_samples,
            n_iterations=n_iterations,
        )

        # Reshape back to the original shape
        return result.view(original_shape)

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_logits, teacher_temp, update_centers=True):
        """Apply classical centering and sharpening to teacher logits.

        **FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.**

        This method subtracts the center (an EMA of batch means) from the teacher
        logits before applying softmax. This prevents mode collapse by ensuring
        balanced prototype usage.

        Args:
            teacher_logits: Teacher logits [*, out_dim]
            teacher_temp: Temperature for teacher softmax
            update_centers: Whether to apply the queued center update before centering

        Returns:
            Teacher probabilities after centering [*, out_dim]
        """
        if update_centers:
            self.apply_center_update()
        if self.center is not None:
            return F.softmax((teacher_logits - self.center) / teacher_temp, dim=-1)
        else:
            return F.softmax(teacher_logits / teacher_temp, dim=-1)

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Queue an async center update from teacher logits.

        **FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.**

        Starts an asynchronous all-reduce for distributed training. The update is
        applied later, when softmax_center_teacher() is called with
        update_centers=True. This allows the all-reduce to overlap with the
        backward pass for efficiency.

        Typical usage:
            teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, temp)
            loss = dino_loss(student_logits, teacher_probs)
            dino_loss.update_center(teacher_logits)  # Start async update
            # ... backward pass happens here, overlapping with the all-reduce ...
            # Next iteration: softmax_center_teacher() will call apply_center_update()

        Args:
            teacher_output: Teacher logits [n_views, batch_size, out_dim]
        """
        # Mark as not updated yet
        self.updated = False
        self.len_teacher_output = len(teacher_output)

        # Compute the batch mean
        self.async_batch_center = torch.sum(teacher_output.mean(1), dim=0, keepdim=True)

        # Start the async all-reduce across GPUs
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            self.reduce_handle = torch.distributed.all_reduce(
                self.async_batch_center, async_op=True
            )

    @torch.no_grad()
    def apply_center_update(self):
        """Apply the queued center update with EMA.

        **FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.**

        Waits for the async all-reduce to complete and updates self.center with an
        EMA. Automatically called by softmax_center_teacher() if
        update_centers=True.
        """
        if self.updated is False:
            world_size = (
                torch.distributed.get_world_size()
                if torch.distributed.is_available()
                and torch.distributed.is_initialized()
                else 1
            )

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_output * world_size)

            # Initialize the center on the first call
            if self.center is None:
                self.center = _t.clone()
            else:
                self.center = self.center * self.center_momentum + _t * (
                    1 - self.center_momentum
                )

            self.updated = True
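

# Illustrative sketch (not part of the library): a toy end-to-end run of
# DINOv1Loss with random logits standing in for the outputs of a prototype
# head. The shapes, temperatures, and the helper name `_example_dinov1_loss`
# are assumptions chosen for demonstration only.
def _example_dinov1_loss():
    n_views, batch_size, out_dim = 2, 8, 128
    student_logits = torch.randn(n_views, batch_size, out_dim)
    teacher_logits = torch.randn(n_views, batch_size, out_dim)
    dino_loss = DINOv1Loss()

    # Approach 1: classical centering with the async EMA center update.
    teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
    loss_centering = dino_loss(student_logits, teacher_probs)
    dino_loss.update_center(teacher_logits)  # queue the center update

    # Approach 2: Sinkhorn-Knopp normalization, no centering state involved.
    teacher_probs = dino_loss.sinkhorn_knopp_teacher(
        teacher_logits, teacher_temp=0.04, num_samples=n_views * batch_size
    )
    loss_sinkhorn = dino_loss(student_logits, teacher_probs)
    return loss_centering, loss_sinkhorn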


class iBOTPatchLoss(torch.nn.Module):
    """iBOT patch-level prediction loss for masked patch prediction.

    This loss computes cross-entropy between teacher and student patch
    predictions for masked patches only. Uses Sinkhorn-Knopp normalization
    exclusively (as in DINOv2/v3) to prevent mode collapse.

    Args:
        student_temp (float): Temperature for student softmax. Default is 0.1.
    """

    def __init__(
        self,
        student_temp: float = 0.1,
    ):
        super().__init__()
        self.student_temp = student_temp

    def forward(self, student_patch_logits, teacher_patch_probs):
        """Compute the iBOT cross-entropy loss for masked patches.

        Args:
            student_patch_logits: Student patch logits [n_masked_total, patch_out_dim]
            teacher_patch_probs: Teacher probabilities [n_masked_total, patch_out_dim]

        Returns:
            Scalar iBOT loss value
        """
        loss = cross_entropy_loss(
            teacher_patch_probs, student_patch_logits, self.student_temp
        )
        return loss.mean()

    @torch.no_grad()
    def sinkhorn_knopp_teacher(
        self,
        teacher_patch_tokens,
        teacher_temp,
        num_samples=None,
        n_iterations=3,
    ):
        """Apply Sinkhorn-Knopp optimal transport normalization to teacher patch logits.

        This method applies optimal transport to enforce an exactly uniform
        distribution across prototypes. Used exclusively in DINOv2/v3 for the
        iBOT patch loss.

        Args:
            teacher_patch_tokens: Teacher patch logits [n_masked, patch_out_dim]
            teacher_temp: Temperature for softmax
            num_samples: Total number of masked patches across all GPUs (int or
                tensor). If None, inferred from the shape.
            n_iterations: Number of Sinkhorn iterations (default: 3)

        Returns:
            Teacher probabilities [n_masked, patch_out_dim] with uniform prototype
            distribution
        """
        return sinkhorn_knopp(
            teacher_output=teacher_patch_tokens,
            teacher_temp=teacher_temp,
            num_samples=num_samples,
            n_iterations=n_iterations,
        )
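

# Illustrative sketch (not part of the library): applying iBOTPatchLoss to a
# batch of masked patch tokens. The shapes, temperatures, and the helper name
# `_example_ibot_patch_loss` are assumptions chosen for demonstration only.
def _example_ibot_patch_loss():
    n_masked, patch_out_dim = 16, 128
    student_patch_logits = torch.randn(n_masked, patch_out_dim)
    teacher_patch_logits = torch.randn(n_masked, patch_out_dim)

    ibot_loss = iBOTPatchLoss(student_temp=0.1)
    teacher_probs = ibot_loss.sinkhorn_knopp_teacher(
        teacher_patch_logits, teacher_temp=0.04, num_samples=n_masked
    )
    return ibot_loss(student_patch_logits, teacher_probs)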


class DINOv2Loss(torch.nn.Module):
    """DINOv2 loss combining the CLS token and masked patch losses.

    DINOv2 combines two losses:
    - DINOv1Loss: CLS token distillation (global views) - uses Sinkhorn-Knopp
    - iBOTPatchLoss: Masked patch prediction - uses Sinkhorn-Knopp

    Both losses use Sinkhorn-Knopp normalization in DINOv2.

    Args:
        dino_loss_weight (float): Weight for the CLS token loss. Default is 1.0.
        ibot_loss_weight (float): Weight for the iBOT patch loss. Default is 1.0.
        temperature_student (float): Temperature for the student softmax in DINO.
            Default is 0.1.
        center_momentum (float): EMA momentum for DINO centering (not used by
            iBOT). Default is 0.9.
        student_temp (float): Temperature for the student softmax in iBOT.
            Default is 0.1.
    """

    def __init__(
        self,
        dino_loss_weight: float = 1.0,
        ibot_loss_weight: float = 1.0,
        temperature_student: float = 0.1,
        center_momentum: float = 0.9,
        student_temp: float = 0.1,
    ):
        super().__init__()
        self.dino_loss_weight = dino_loss_weight
        self.ibot_loss_weight = ibot_loss_weight
        self.dino_loss = DINOv1Loss(
            temperature_student=temperature_student,
            center_momentum=center_momentum,
        )
        self.ibot_loss = iBOTPatchLoss(
            student_temp=student_temp,
        )

    def forward(
        self,
        student_cls_logits: torch.Tensor,
        teacher_cls_probs: torch.Tensor,
        student_patch_logits: torch.Tensor = None,
        teacher_patch_probs: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute the combined DINOv2 loss.

        Args:
            student_cls_logits: Student CLS logits [n_views, batch, out_dim]
            teacher_cls_probs: Teacher CLS probs [n_views, batch, out_dim]
            student_patch_logits: Student patch logits [n_masked_total, patch_out_dim] or None
            teacher_patch_probs: Teacher patch probs [n_masked_total, patch_out_dim] or None

        Returns:
            Combined weighted loss
        """
        dino_loss = self.dino_loss(student_cls_logits, teacher_cls_probs)

        if student_patch_logits is None or teacher_patch_probs is None:
            return self.dino_loss_weight * dino_loss

        ibot_loss = self.ibot_loss(student_patch_logits, teacher_patch_probs)
        return self.dino_loss_weight * dino_loss + self.ibot_loss_weight * ibot_loss
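

# Illustrative sketch (not part of the library): a toy forward pass through
# DINOv2Loss, normalizing the teacher CLS and patch logits with the
# Sinkhorn-Knopp helpers of the two sub-losses. The shapes, temperatures, and
# the helper name `_example_dinov2_loss` are assumptions chosen for
# demonstration only.
def _example_dinov2_loss():
    n_views, batch_size, out_dim = 2, 8, 128
    n_masked, patch_out_dim = 16, 128
    loss_fn = DINOv2Loss(dino_loss_weight=1.0, ibot_loss_weight=1.0)

    # CLS-level inputs, normalized with the DINO sub-loss.
    student_cls = torch.randn(n_views, batch_size, out_dim)
    teacher_cls = loss_fn.dino_loss.sinkhorn_knopp_teacher(
        torch.randn(n_views, batch_size, out_dim),
        teacher_temp=0.04,
        num_samples=n_views * batch_size,
    )

    # Patch-level inputs for the masked positions, normalized with the iBOT sub-loss.
    student_patches = torch.randn(n_masked, patch_out_dim)
    teacher_patches = loss_fn.ibot_loss.sinkhorn_knopp_teacher(
        torch.randn(n_masked, patch_out_dim), teacher_temp=0.04, num_samples=n_masked
    )

    return loss_fn(student_cls, teacher_cls, student_patches, teacher_patches)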