stable_pretraining.losses package#

Submodules#

stable_pretraining.losses.dino module#

DINO self-distillation losses.

This module contains losses for DINO-style self-distillation, including:
  • DINOLoss: CLS token distillation

  • iBOTPatchLoss: Masked patch prediction

Reference: DINOv2/v3 papers and facebookresearch/dinov3

class stable_pretraining.losses.dino.DINOv1Loss(temperature_student: float = 0.1, center_momentum: float = 0.9)[source]#

Bases: Module

DINOv1 loss for self-distillation with cross-entropy [Caron et al., 2021].

This loss computes cross-entropy between teacher and student logits after applying temperature scaling and normalization. The teacher uses either classical centering or Sinkhorn-Knopp normalization to prevent mode collapse.

Usage:

```python
dino_loss = DINOv1Loss()

# Get logits from the prototype layer
student_logits = prototype_layer(student_embeddings)  # [n_views, B, out_dim]
teacher_logits = prototype_layer(teacher_embeddings)  # [n_views, B, out_dim]

# Approach 1: Classical centering (recommended, faster)
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Queue async center update

# Approach 2: Sinkhorn-Knopp (more principled, slower, no centering needed)
n_views, batch_size, _ = teacher_logits.shape
num_samples = n_views * batch_size  # Total samples across views
teacher_probs = dino_loss.sinkhorn_knopp_teacher(
    teacher_logits, teacher_temp=0.04, num_samples=num_samples
)
loss = dino_loss(student_logits, teacher_probs)
# No update_center() needed for Sinkhorn-Knopp!
```

Parameters:
  • temperature_student (float) – Temperature for student softmax. Default is 0.1.

  • center_momentum (float) – EMA momentum for center update. Default is 0.9.

apply_center_update()[source]#

Apply the queued center update with EMA.

FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.

Waits for async all-reduce to complete and updates self.center with EMA. Automatically called by softmax_center_teacher() if update_centers=True.

forward(student_logits: Tensor, teacher_probs: Tensor) Tensor[source]#

Compute DINO cross-entropy loss.

This is a pure loss computation with no side effects (no centering, no updates). Teacher probabilities should be pre-processed with softmax_center_teacher() or sinkhorn_knopp_teacher(). Center updates should be done separately with update_center().

Parameters:
  • student_logits – Student logits [n_views, batch_size, out_dim]

  • teacher_probs – Teacher probabilities (already normalized) [n_views, batch_size, out_dim]

Returns:

Scalar DINO loss value (cross-entropy averaged over view pairs, excluding diagonal)

Shape:
  • student_logits: (S, B, K) where S = student views, B = batch size, K = out_dim

  • teacher_probs: (T, B, K) where T = teacher views

  • output: scalar

sinkhorn_knopp_teacher(teacher_logits, teacher_temp, num_samples=None, n_iterations=3)[source]#

Apply Sinkhorn-Knopp optimal transport normalization to teacher logits.

FOR SINKHORN-KNOPP APPROACH ONLY. DOES NOT USE CENTER.

This method applies Sinkhorn-Knopp normalization to enforce an exactly uniform distribution across prototypes without using centering. It is more principled than centering but more expensive, and is used in SwAV and DINOv3 for better theoretical guarantees.

Note: When using Sinkhorn-Knopp, you do NOT need to call update_center() or apply_center_update() since centering is not used.

Parameters:
  • teacher_logits – Teacher logits [*, out_dim]. Can be any shape as long as last dim is out_dim. Common shapes: [batch, out_dim] or [n_views, batch, out_dim]

  • teacher_temp – Temperature for softmax

  • num_samples – Total number of samples across all GPUs (int or tensor). If None, inferred from shape assuming [batch, out_dim] format. For multi-view [n_views, batch, out_dim], pass n_views * batch explicitly.

  • n_iterations – Number of Sinkhorn iterations (default: 3)

Returns:

Teacher probabilities [same shape as input] with uniform prototype distribution

softmax_center_teacher(teacher_logits, teacher_temp, update_centers=True)[source]#

Apply classical centering and sharpening to teacher logits.

FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.

This method subtracts the center (EMA of batch means) from teacher logits before applying softmax. This prevents mode collapse by ensuring balanced prototype usage.

Parameters:
  • teacher_logits – Teacher logits [*, out_dim]

  • teacher_temp – Temperature for teacher softmax

  • update_centers – Whether to apply queued center update before centering

Returns:

Teacher probabilities after centering [*, out_dim]

update_center(teacher_output)[source]#

Queue async center update from teacher logits.

FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.

Starts an asynchronous all-reduce for distributed training. The update is applied later when softmax_center_teacher() is called with update_centers=True. This allows the all-reduce to overlap with backward pass for efficiency.

Typical usage:

```python
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, temp)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Start async update
# ... backward pass happens here, overlapping with all-reduce ...
# Next iteration: softmax_center_teacher() will call apply_center_update()
```

Parameters:

teacher_output – Teacher logits [n_views, batch_size, out_dim]

class stable_pretraining.losses.dino.DINOv2Loss(dino_loss_weight: float = 1.0, ibot_loss_weight: float = 1.0, temperature_student: float = 0.1, center_momentum: float = 0.9, student_temp: float = 0.1)[source]#

Bases: Module

DINOv2 loss combining CLS token and masked patch losses.

DINOv2 combines two losses:
  • DINOv1Loss: CLS token distillation (global views) – uses Sinkhorn-Knopp

  • iBOTPatchLoss: Masked patch prediction – uses Sinkhorn-Knopp

Both losses use Sinkhorn-Knopp normalization in DINOv2.

Parameters:
  • dino_loss_weight (float) – Weight for CLS token loss. Default is 1.0.

  • ibot_loss_weight (float) – Weight for iBOT patch loss. Default is 1.0.

  • temperature_student (float) – Temperature for student softmax in DINO. Default is 0.1.

  • center_momentum (float) – EMA momentum for DINO centering (not used by iBOT). Default is 0.9.

  • student_temp (float) – Temperature for student softmax in iBOT. Default is 0.1.

forward(student_cls_logits: Tensor, teacher_cls_probs: Tensor, student_patch_logits: Tensor = None, teacher_patch_probs: Tensor = None) Tensor[source]#

Compute combined DINOv2 loss.

Parameters:
  • student_cls_logits – Student CLS logits [n_views, batch, out_dim]

  • teacher_cls_probs – Teacher CLS probs [n_views, batch, out_dim]

  • student_patch_logits – Student patch logits [n_masked_total, patch_out_dim] or None

  • teacher_patch_probs – Teacher patch probs [n_masked_total, patch_out_dim] or None

Returns:

Combined weighted loss
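
For illustration, a minimal sketch of a forward call with random placeholder tensors standing in for the prototype-head outputs (shapes follow the parameter descriptions above; the dimensions are arbitrary):

```python
import torch

from stable_pretraining.losses.dino import DINOv2Loss

n_views, batch, out_dim, patch_out_dim, n_masked = 2, 8, 1024, 1024, 37
loss_fn = DINOv2Loss(dino_loss_weight=1.0, ibot_loss_weight=1.0)

# Placeholder CLS logits/probabilities (normally produced by the prototype
# heads and teacher-side Sinkhorn-Knopp normalization).
student_cls_logits = torch.randn(n_views, batch, out_dim)
teacher_cls_probs = torch.softmax(torch.randn(n_views, batch, out_dim), dim=-1)

# Placeholder masked-patch logits/probabilities.
student_patch_logits = torch.randn(n_masked, patch_out_dim)
teacher_patch_probs = torch.softmax(torch.randn(n_masked, patch_out_dim), dim=-1)

loss = loss_fn(
    student_cls_logits,
    teacher_cls_probs,
    student_patch_logits=student_patch_logits,
    teacher_patch_probs=teacher_patch_probs,
)
```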

stable_pretraining.losses.dino.cross_entropy_loss(t, s, temp)[source]#

Cross-entropy loss function for iBOT.

Computes per-sample cross-entropy: -Σ t[i] * log_softmax(s[i]/temp)

Parameters:
  • t – Teacher predictions (probabilities) [*, D]

  • s – Student predictions (logits) [*, D]

  • temp – Temperature for student softmax

Returns:

Per-sample cross-entropy loss [*] (positive, lower is better)
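
A toy sketch of this computation with arbitrary shapes and random placeholder tensors:

```python
import torch

from stable_pretraining.losses.dino import cross_entropy_loss

n, d = 4, 16
t = torch.softmax(torch.randn(n, d), dim=-1)  # teacher probabilities
s = torch.randn(n, d)                         # student logits
per_sample = cross_entropy_loss(t, s, temp=0.1)  # per-sample loss, shape [n]
```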

class stable_pretraining.losses.dino.iBOTPatchLoss(student_temp: float = 0.1)[source]#

Bases: Module

iBOT patch-level prediction loss for masked patch prediction.

This loss computes cross-entropy between teacher and student patch predictions for masked patches only. Uses Sinkhorn-Knopp normalization exclusively (as in DINOv2/v3) to prevent mode collapse.

Parameters:

student_temp (float) – Temperature for student softmax. Default is 0.1.

forward(student_patch_logits, teacher_patch_probs)[source]#

Compute iBOT cross-entropy loss for masked patches.

Parameters:
  • student_patch_logits – Student patch logits [n_masked_total, patch_out_dim]

  • teacher_patch_probs – Teacher probabilities [n_masked_total, patch_out_dim]

Returns:

Scalar iBOT loss value

sinkhorn_knopp_teacher(teacher_patch_tokens, teacher_temp, num_samples=None, n_iterations=3)[source]#

Apply Sinkhorn-Knopp optimal transport normalization to teacher patch logits.

This method applies optimal transport to enforce an exactly uniform distribution across prototypes. It is used exclusively in DINOv2/v3 for the iBOT patch loss.

Parameters:
  • teacher_patch_tokens – Teacher patch logits [n_masked, patch_out_dim]

  • teacher_temp – Temperature for softmax

  • num_samples – Total number of masked patches across all GPUs (int or tensor). If None, inferred from shape.

  • n_iterations – Number of Sinkhorn iterations (default: 3)

Returns:

Teacher probabilities [n_masked, patch_out_dim] with uniform prototype distribution
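
A minimal sketch tying sinkhorn_knopp_teacher() and forward() together, with random placeholder tensors standing in for the masked-patch head outputs:

```python
import torch

from stable_pretraining.losses.dino import iBOTPatchLoss

n_masked, patch_out_dim = 37, 1024
ibot_loss = iBOTPatchLoss(student_temp=0.1)

teacher_patch_logits = torch.randn(n_masked, patch_out_dim)  # placeholder teacher head output
student_patch_logits = torch.randn(n_masked, patch_out_dim)  # placeholder student head output

# Normalize the teacher side with Sinkhorn-Knopp, then compute the patch loss.
teacher_patch_probs = ibot_loss.sinkhorn_knopp_teacher(
    teacher_patch_logits, teacher_temp=0.07, num_samples=n_masked
)
loss = ibot_loss(student_patch_logits, teacher_patch_probs)
```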

stable_pretraining.losses.joint_embedding module#

Joint embedding SSL losses.

This module contains joint embedding methods that learn to embed different views of the same image close together in representation space. It includes both contrastive (NTXEntLoss) and non-contrastive (BYOL, VICReg, Barlow Twins) methods.

class stable_pretraining.losses.joint_embedding.BYOLLoss(*args, **kwargs)[source]#

Bases: Module

Normalized MSE objective used in BYOL [Grill et al., 2020].

Computes the mean squared error between L2-normalized online predictions and L2-normalized target projections.

forward(online_pred: Tensor, target_proj: Tensor) Tensor[source]#

Compute BYOL loss.

Parameters:
  • online_pred – Predictions from the online network predictor.

  • target_proj – Projections from the target network (no gradient).

Returns:

Scalar loss value.

Return type:

torch.Tensor
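
A minimal sketch with random placeholder tensors standing in for the online predictor and target projector outputs:

```python
import torch

from stable_pretraining.losses.joint_embedding import BYOLLoss

byol_loss = BYOLLoss()
online_pred = torch.randn(32, 256)           # online network predictor output
target_proj = torch.randn(32, 256).detach()  # target network projection (no gradient)
loss = byol_loss(online_pred, target_proj)
```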

class stable_pretraining.losses.joint_embedding.BarlowTwinsLoss(lambd: float = 0.005)[source]#

Bases: Module

SSL objective used in Barlow Twins [Zbontar et al., 2021].

Parameters:

lambd (float, optional) – The weight of the off-diagonal terms in the loss. Default is 5e-3.

forward(z_i, z_j)[source]#

Compute the loss of the Barlow Twins model.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float
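
A minimal sketch with random placeholder embeddings for the two views:

```python
import torch

from stable_pretraining.losses.joint_embedding import BarlowTwinsLoss

bt_loss = BarlowTwinsLoss(lambd=5e-3)
z_i = torch.randn(64, 128)  # embeddings of the first augmented view
z_j = torch.randn(64, 128)  # embeddings of the second augmented view
loss = bt_loss(z_i, z_j)
```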

class stable_pretraining.losses.joint_embedding.InfoNCELoss(temperature: float = 0.07)[source]#

Bases: Module

InfoNCE contrastive loss (one-directional).

This module computes the cross-entropy loss between anchor embeddings and a set of candidate embeddings, given the ground-truth targets. It forms the core mathematical operation for losses like those in CLIP and SimCLR.

Parameters:

temperature (float, optional) – The temperature scaling factor. Default is 0.07.

forward(anchors: Tensor, candidates: Tensor, targets: Tensor, mask: Tensor | None = None, logit_scale: Tensor | float | None = None) Tensor[source]#

Computes the contrastive loss.

Parameters:
  • anchors (torch.Tensor) – The primary set of embeddings (queries) of shape [N, D].

  • candidates (torch.Tensor) – The set of embeddings to contrast against (keys) of shape [M, D].

  • targets (torch.Tensor) – A 1D tensor of ground-truth indices of shape [N], where targets[i] is the index of the positive candidate for anchors[i].

  • mask (torch.Tensor, optional) – A boolean mask of shape [N, M] to exclude certain anchor-candidate pairs from the loss calculation. Values set to True will be ignored.

  • logit_scale (torch.Tensor | float, optional) – The temperature scaling factor. Default is self.temperature.

Returns:

A scalar loss value.

Return type:

torch.Tensor
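
A minimal sketch assuming the common setup where each anchor's positive is the candidate at the same index (all tensors are random placeholders):

```python
import torch

from stable_pretraining.losses.joint_embedding import InfoNCELoss

info_nce = InfoNCELoss(temperature=0.07)
anchors = torch.randn(16, 128)     # queries [N, D]
candidates = torch.randn(16, 128)  # keys [M, D]
targets = torch.arange(16)         # positive for anchors[i] is candidates[i]
loss = info_nce(anchors, candidates, targets)
```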

class stable_pretraining.losses.joint_embedding.NTXEntLoss(temperature: float = 0.5)[source]#

Bases: InfoNCELoss

Normalized temperature-scaled cross entropy loss.

Introduced in the SimCLR paper [Chen et al., 2020]. Also used in MoCo [He et al., 2020].

Parameters:

temperature (float, optional) – The temperature scaling factor. Default is 0.5.

forward(z_i: Tensor, z_j: Tensor) Tensor[source]#

Compute the NT-Xent loss.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed contrastive loss.

Return type:

float
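
A minimal sketch with random placeholder projections of the two augmented views:

```python
import torch

from stable_pretraining.losses.joint_embedding import NTXEntLoss

ntxent = NTXEntLoss(temperature=0.5)
z_i = torch.randn(64, 128)  # projections of the first augmented view
z_j = torch.randn(64, 128)  # projections of the second augmented view
loss = ntxent(z_i, z_j)
```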

class stable_pretraining.losses.joint_embedding.SwAVLoss(temperature: float = 0.1, sinkhorn_iterations: int = 3, epsilon: float = 0.05)[source]#

Bases: Module

Computes the SwAV loss, optionally using a feature queue.

This loss function contains the core components of the SwAV algorithm, including the Sinkhorn-Knopp algorithm for online clustering and the swapped-prediction contrastive task.

Parameters:
  • temperature (float, optional) – The temperature scaling factor for the softmax in the swapped prediction task. Default is 0.1.

  • sinkhorn_iterations (int, optional) – The number of iterations for the Sinkhorn-Knopp algorithm. Default is 3.

  • epsilon (float, optional) – A small value for numerical stability in the Sinkhorn-Knopp algorithm. Default is 0.05.

Note

Introduced in the SwAV paper [Caron et al., 2020].

forward(proj1, proj2, prototypes, queue_feats=None)[source]#

Compute the SwAV loss.

Parameters:
  • proj1 (torch.Tensor) – Raw projections of the first view.

  • proj2 (torch.Tensor) – Raw projections of the second view.

  • prototypes (torch.nn.Module) – The prototype vectors.

  • queue_feats (torch.Tensor, optional) – Raw features from the queue.
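
A minimal sketch of a forward call; the bias-free nn.Linear standing in for the prototype vectors and the chosen dimensions are assumptions for illustration, not requirements of the implementation:

```python
import torch

from stable_pretraining.losses.joint_embedding import SwAVLoss

swav_loss = SwAVLoss(temperature=0.1, sinkhorn_iterations=3, epsilon=0.05)
prototypes = torch.nn.Linear(128, 3000, bias=False)  # prototype vectors as a module (assumed setup)

proj1 = torch.randn(64, 128)  # raw projections of the first view
proj2 = torch.randn(64, 128)  # raw projections of the second view
loss = swav_loss(proj1, proj2, prototypes)  # optionally also pass queue_feats=...
```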

sinkhorn(scores)[source]#

Applies the Sinkhorn-Knopp algorithm.

swapped_prediction(scores, q)[source]#

Computes the cross-entropy loss for the swapped prediction task.

class stable_pretraining.losses.joint_embedding.VICRegLoss(sim_coeff: float = 25, std_coeff: float = 25, cov_coeff: float = 1, epsilon: float = 0.0001)[source]#

Bases: Module

SSL objective used in VICReg [Bardes et al., 2021].

Parameters:
  • sim_coeff (float, optional) – The weight of the similarity loss (attractive term). Default is 25.

  • std_coeff (float, optional) – The weight of the standard deviation loss. Default is 25.

  • cov_coeff (float, optional) – The weight of the covariance loss. Default is 1.

  • epsilon (float, optional) – Small value to avoid division by zero. Default is 1e-4.

forward(z_i, z_j)[source]#

Compute the loss of the VICReg model.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float
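
A minimal sketch with random placeholder embeddings:

```python
import torch

from stable_pretraining.losses.joint_embedding import VICRegLoss

vicreg_loss = VICRegLoss(sim_coeff=25, std_coeff=25, cov_coeff=1)
z_i = torch.randn(64, 128)  # embeddings of the first augmented view
z_j = torch.randn(64, 128)  # embeddings of the second augmented view
loss = vicreg_loss(z_i, z_j)
```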

stable_pretraining.losses.multimodal module#

Multimodal SSL losses.

This module contains losses for multimodal self-supervised learning, particularly for image-text contrastive learning like CLIP.

class stable_pretraining.losses.multimodal.CLIPLoss(temperature: float = 0.07)[source]#

Bases: InfoNCELoss

CLIP loss (symmetric bidirectional InfoNCE).

As used in CLIP [Radford et al., 2021]. Computes symmetric cross-entropy over image-text and text-image logits.

Parameters:

temperature (float, optional) – Softmax temperature. Default is 0.07. (If you use a learnable logit_scale in your model, pass it to forward(…) and this temperature will be ignored.)

forward(feats_i: Tensor, feats_j: Tensor, logit_scale: Tensor | float | None = None) Tensor[source]#

Computes the contrastive loss.

Parameters:
  • feats_i (torch.Tensor) – Embeddings of the first modality (e.g., images) of shape [N, D].

  • feats_j (torch.Tensor) – Embeddings of the second modality (e.g., text) of shape [N, D].

  • logit_scale (torch.Tensor | float, optional) – The temperature scaling factor. Default is self.temperature.

Returns:

A scalar loss value.

Return type:

torch.Tensor
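
A minimal sketch with random placeholder image and text embeddings; the temperature override in the last line is illustrative only:

```python
import torch

from stable_pretraining.losses.multimodal import CLIPLoss

clip_loss = CLIPLoss(temperature=0.07)
image_feats = torch.randn(32, 512)  # image embeddings [N, D]
text_feats = torch.randn(32, 512)   # text embeddings [N, D]

loss = clip_loss(image_feats, text_feats)

# Optionally override the temperature for this call, e.g. from a learnable parameter:
loss = clip_loss(image_feats, text_feats, logit_scale=0.05)
```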

stable_pretraining.losses.reconstruction module#

Reconstruction-based SSL losses.

This module contains reconstruction-based self-supervised learning losses such as Masked Autoencoder (MAE).

stable_pretraining.losses.reconstruction.mae(target, pred, mask, norm_pix_loss=False)[source]#

Compute masked autoencoder loss.

Parameters:
  • target – [N, L, p*p*3] target images

  • pred – [N, L, p*p*3] predicted images

  • mask – [N, L], 0 is keep, 1 is remove

  • norm_pix_loss – whether to normalize pixels

Returns:

mean loss value

Return type:

loss
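
A minimal sketch with random placeholder tensors; the patch size and masking ratio below are arbitrary choices for illustration:

```python
import torch

from stable_pretraining.losses.reconstruction import mae

N, L, p = 8, 196, 16
target = torch.randn(N, L, p * p * 3)     # patchified target images
pred = torch.randn(N, L, p * p * 3)       # decoder predictions
mask = (torch.rand(N, L) > 0.25).float()  # 1 = removed (masked), 0 = kept
loss = mae(target, pred, mask, norm_pix_loss=False)
```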

stable_pretraining.losses.utils module#

Utilities for SSL losses.

This module provides helper functions and utilities used by various SSL losses, such as Sinkhorn-Knopp optimal transport algorithm for DINO and iBOT.

class stable_pretraining.losses.utils.NegativeCosineSimilarity(*args, **kwargs)[source]#

Bases: Module

Negative cosine similarity objective.

This objective is used for instance in BYOL [Grill et al., 2020] or SimSiam [Chen and He, 2021].

forward(z_i, z_j)[source]#

Compute the negative cosine similarity loss between the two views.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float
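
A minimal sketch with random placeholder tensors, e.g. a predictor output and a stop-gradient projection:

```python
import torch

from stable_pretraining.losses.utils import NegativeCosineSimilarity

neg_cos = NegativeCosineSimilarity()
z_i = torch.randn(32, 256)           # e.g. predictor output for the first view
z_j = torch.randn(32, 256).detach()  # e.g. stop-gradient projection of the second view
loss = neg_cos(z_i, z_j)
```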

stable_pretraining.losses.utils.off_diagonal(x)[source]#

Return a flattened view of the off-diagonal elements of a square matrix.
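
A tiny sketch on a 3x3 matrix:

```python
import torch

from stable_pretraining.losses.utils import off_diagonal

x = torch.arange(9.0).reshape(3, 3)  # diagonal entries are 0., 4., 8.
off_diag = off_diagonal(x)           # flat tensor containing the six off-diagonal entries
```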

stable_pretraining.losses.utils.sinkhorn_knopp(teacher_output: Tensor, teacher_temp: float, num_samples: int | Tensor, n_iterations: int = 3) Tensor[source]#

Sinkhorn-Knopp algorithm for optimal transport normalization.

This is an alternative to simple centering used in DINO/iBOT losses. It performs optimal transport to assign samples to prototypes, ensuring a more uniform distribution across prototypes.

Reference: DINOv3 implementation facebookresearch/dinov3

Parameters:
  • teacher_output – Teacher predictions [batch, prototypes] or [n_samples, prototypes]

  • teacher_temp – Temperature for softmax

  • num_samples – Number of samples to assign. Can be: - int: Fixed batch size (e.g., batch_size * world_size for DINO) - torch.Tensor: Variable count (e.g., n_masked_patches for iBOT)

  • n_iterations – Number of Sinkhorn iterations (default: 3)

Returns:

Normalized probabilities [batch, prototypes] summing to 1 over prototypes

Examples

```python
# DINO CLS token loss (fixed batch size)
Q = sinkhorn_knopp(teacher_cls_output, teacher_temp=0.04,
                   num_samples=batch_size * world_size)

# iBOT patch loss (variable number of masked patches)
Q = sinkhorn_knopp(teacher_patch_output, teacher_temp=0.04,
                   num_samples=n_masked_patches_tensor)
```

Module contents#

SSL Losses.

This module provides various self-supervised learning loss functions organized by category:
  • DINO losses: Self-distillation methods (DINOLoss, iBOTPatchLoss)

  • Joint embedding losses: Contrastive and non-contrastive methods (BYOL, VICReg, Barlow Twins, SimCLR)

  • Reconstruction losses: Masked prediction methods (MAE)

  • Utilities: Helper functions (sinkhorn_knopp, off_diagonal, NegativeCosineSimilarity)

class stable_pretraining.losses.BYOLLoss(*args, **kwargs)[source]#

Bases: Module

Normalized MSE objective used in BYOL [Grill et al., 2020].

Computes the mean squared error between L2-normalized online predictions and L2-normalized target projections.

forward(online_pred: Tensor, target_proj: Tensor) Tensor[source]#

Compute BYOL loss.

Parameters:
  • online_pred – Predictions from the online network predictor.

  • target_proj – Projections from the target network (no gradient).

Returns:

Scalar loss value.

Return type:

torch.Tensor

class stable_pretraining.losses.BarlowTwinsLoss(lambd: float = 0.005)[source]#

Bases: Module

SSL objective used in Barlow Twins [Zbontar et al., 2021].

Parameters:

lambd (float, optional) – The weight of the off-diagonal terms in the loss. Default is 5e-3.

forward(z_i, z_j)[source]#

Compute the loss of the Barlow Twins model.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float

class stable_pretraining.losses.CLIPLoss(temperature: float = 0.07)[source]#

Bases: InfoNCELoss

CLIP loss (symmetric bidirectional InfoNCE).

As used in CLIP [Radford et al., 2021]. Computes symmetric cross-entropy over image-text and text-image logits.

Parameters:

temperature (float, optional) – Softmax temperature. Default is 0.07. (If you use a learnable logit_scale in your model, pass it to forward(…) and this temperature will be ignored.)

forward(feats_i: Tensor, feats_j: Tensor, logit_scale: Tensor | float | None = None) Tensor[source]#

Computes the contrastive loss.

Parameters:
  • feats_i (torch.Tensor) – Embeddings of the first modality (e.g., images) of shape [N, D].

  • feats_j (torch.Tensor) – Embeddings of the second modality (e.g., text) of shape [N, D].

  • logit_scale (torch.Tensor | float, optional) – The temperature scaling factor. Default is self.temperature.

Returns:

A scalar loss value.

Return type:

torch.Tensor

class stable_pretraining.losses.DINOv1Loss(temperature_student: float = 0.1, center_momentum: float = 0.9)[source]#

Bases: Module

DINOv1 loss for self-distillation with cross-entropy [Caron et al., 2021].

This loss computes cross-entropy between teacher and student logits after applying temperature scaling and normalization. The teacher uses either classical centering or Sinkhorn-Knopp normalization to prevent mode collapse.

Usage:

```python
dino_loss = DINOv1Loss()

# Get logits from the prototype layer
student_logits = prototype_layer(student_embeddings)  # [n_views, B, out_dim]
teacher_logits = prototype_layer(teacher_embeddings)  # [n_views, B, out_dim]

# Approach 1: Classical centering (recommended, faster)
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Queue async center update

# Approach 2: Sinkhorn-Knopp (more principled, slower, no centering needed)
n_views, batch_size, _ = teacher_logits.shape
num_samples = n_views * batch_size  # Total samples across views
teacher_probs = dino_loss.sinkhorn_knopp_teacher(
    teacher_logits, teacher_temp=0.04, num_samples=num_samples
)
loss = dino_loss(student_logits, teacher_probs)
# No update_center() needed for Sinkhorn-Knopp!
```

Parameters:
  • temperature_student (float) – Temperature for student softmax. Default is 0.1.

  • center_momentum (float) – EMA momentum for center update. Default is 0.9.

apply_center_update()[source]#

Apply the queued center update with EMA.

FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.

Waits for async all-reduce to complete and updates self.center with EMA. Automatically called by softmax_center_teacher() if update_centers=True.

forward(student_logits: Tensor, teacher_probs: Tensor) Tensor[source]#

Compute DINO cross-entropy loss.

This is a pure loss computation with no side effects (no centering, no updates). Teacher probabilities should be pre-processed with softmax_center_teacher() or sinkhorn_knopp_teacher(). Center updates should be done separately with update_center().

Parameters:
  • student_logits – Student logits [n_views, batch_size, out_dim]

  • teacher_probs – Teacher probabilities (already normalized) [n_views, batch_size, out_dim]

Returns:

Scalar DINO loss value (cross-entropy averaged over view pairs, excluding diagonal)

Shape:
  • student_logits: (S, B, K) where S = student views, B = batch size, K = out_dim

  • teacher_probs: (T, B, K) where T = teacher views

  • output: scalar

sinkhorn_knopp_teacher(teacher_logits, teacher_temp, num_samples=None, n_iterations=3)[source]#

Apply Sinkhorn-Knopp optimal transport normalization to teacher logits.

FOR SINKHORN-KNOPP APPROACH ONLY. DOES NOT USE CENTER.

This method applies Sinkhorn-Knopp normalization to enforce an exactly uniform distribution across prototypes without using centering. It is more principled than centering but more expensive, and is used in SwAV and DINOv3 for better theoretical guarantees.

Note: When using Sinkhorn-Knopp, you do NOT need to call update_center() or apply_center_update() since centering is not used.

Parameters:
  • teacher_logits – Teacher logits [*, out_dim]. Can be any shape as long as last dim is out_dim. Common shapes: [batch, out_dim] or [n_views, batch, out_dim]

  • teacher_temp – Temperature for softmax

  • num_samples – Total number of samples across all GPUs (int or tensor). If None, inferred from shape assuming [batch, out_dim] format. For multi-view [n_views, batch, out_dim], pass n_views * batch explicitly.

  • n_iterations – Number of Sinkhorn iterations (default: 3)

Returns:

Teacher probabilities [same shape as input] with uniform prototype distribution

softmax_center_teacher(teacher_logits, teacher_temp, update_centers=True)[source]#

Apply classical centering and sharpening to teacher logits.

FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.

This method subtracts the center (EMA of batch means) from teacher logits before applying softmax. This prevents mode collapse by ensuring balanced prototype usage.

Parameters:
  • teacher_logits – Teacher logits [*, out_dim]

  • teacher_temp – Temperature for teacher softmax

  • update_centers – Whether to apply queued center update before centering

Returns:

Teacher probabilities after centering [*, out_dim]

update_center(teacher_output)[source]#

Queue async center update from teacher logits.

FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.

Starts an asynchronous all-reduce for distributed training. The update is applied later when softmax_center_teacher() is called with update_centers=True. This allows the all-reduce to overlap with backward pass for efficiency.

Typical usage:

```python
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, temp)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Start async update
# ... backward pass happens here, overlapping with all-reduce ...
# Next iteration: softmax_center_teacher() will call apply_center_update()
```

Parameters:

teacher_output – Teacher logits [n_views, batch_size, out_dim]

class stable_pretraining.losses.DINOv2Loss(dino_loss_weight: float = 1.0, ibot_loss_weight: float = 1.0, temperature_student: float = 0.1, center_momentum: float = 0.9, student_temp: float = 0.1)[source]#

Bases: Module

DINOv2 loss combining CLS token and masked patch losses.

DINOv2 combines two losses:
  • DINOv1Loss: CLS token distillation (global views) – uses Sinkhorn-Knopp

  • iBOTPatchLoss: Masked patch prediction – uses Sinkhorn-Knopp

Both losses use Sinkhorn-Knopp normalization in DINOv2.

Parameters:
  • dino_loss_weight (float) – Weight for CLS token loss. Default is 1.0.

  • ibot_loss_weight (float) – Weight for iBOT patch loss. Default is 1.0.

  • temperature_student (float) – Temperature for student softmax in DINO. Default is 0.1.

  • center_momentum (float) – EMA momentum for DINO centering (not used by iBOT). Default is 0.9.

  • student_temp (float) – Temperature for student softmax in iBOT. Default is 0.1.

forward(student_cls_logits: Tensor, teacher_cls_probs: Tensor, student_patch_logits: Tensor = None, teacher_patch_probs: Tensor = None) Tensor[source]#

Compute combined DINOv2 loss.

Parameters:
  • student_cls_logits – Student CLS logits [n_views, batch, out_dim]

  • teacher_cls_probs – Teacher CLS probs [n_views, batch, out_dim]

  • student_patch_logits – Student patch logits [n_masked_total, patch_out_dim] or None

  • teacher_patch_probs – Teacher patch probs [n_masked_total, patch_out_dim] or None

Returns:

Combined weighted loss

class stable_pretraining.losses.NTXEntLoss(temperature: float = 0.5)[source]#

Bases: InfoNCELoss

Normalized temperature-scaled cross entropy loss.

Introduced in the SimCLR paper [Chen et al., 2020]. Also used in MoCo [He et al., 2020].

Parameters:

temperature (float, optional) – The temperature scaling factor. Default is 0.5.

forward(z_i: Tensor, z_j: Tensor) Tensor[source]#

Compute the NT-Xent loss.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed contrastive loss.

Return type:

float

class stable_pretraining.losses.NegativeCosineSimilarity(*args, **kwargs)[source]#

Bases: Module

Negative cosine similarity objective.

This objective is used for instance in BYOL [Grill et al., 2020] or SimSiam [Chen and He, 2021].

forward(z_i, z_j)[source]#

Compute the negative cosine similarity loss between the two views.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float

class stable_pretraining.losses.VICRegLoss(sim_coeff: float = 25, std_coeff: float = 25, cov_coeff: float = 1, epsilon: float = 0.0001)[source]#

Bases: Module

SSL objective used in VICReg [Bardes et al., 2021].

Parameters:
  • sim_coeff (float, optional) – The weight of the similarity loss (attractive term). Default is 25.

  • std_coeff (float, optional) – The weight of the standard deviation loss. Default is 25.

  • cov_coeff (float, optional) – The weight of the covariance loss. Default is 1.

  • epsilon (float, optional) – Small value to avoid division by zero. Default is 1e-4.

forward(z_i, z_j)[source]#

Compute the loss of the VICReg model.

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed loss.

Return type:

float

class stable_pretraining.losses.iBOTPatchLoss(student_temp: float = 0.1)[source]#

Bases: Module

iBOT patch-level prediction loss for masked patch prediction.

This loss computes cross-entropy between teacher and student patch predictions for masked patches only. Uses Sinkhorn-Knopp normalization exclusively (as in DINOv2/v3) to prevent mode collapse.

Parameters:

student_temp (float) – Temperature for student softmax. Default is 0.1.

forward(student_patch_logits, teacher_patch_probs)[source]#

Compute iBOT cross-entropy loss for masked patches.

Parameters:
  • student_patch_logits – Student patch logits [n_masked_total, patch_out_dim]

  • teacher_patch_probs – Teacher probabilities [n_masked_total, patch_out_dim]

Returns:

Scalar iBOT loss value

sinkhorn_knopp_teacher(teacher_patch_tokens, teacher_temp, num_samples=None, n_iterations=3)[source]#

Apply Sinkhorn-Knopp optimal transport normalization to teacher patch logits.

This method applies optimal transport to enforce an exactly uniform distribution across prototypes. It is used exclusively in DINOv2/v3 for the iBOT patch loss.

Parameters:
  • teacher_patch_tokens – Teacher patch logits [n_masked, patch_out_dim]

  • teacher_temp – Temperature for softmax

  • num_samples – Total number of masked patches across all GPUs (int or tensor). If None, inferred from shape.

  • n_iterations – Number of Sinkhorn iterations (default: 3)

Returns:

Teacher probabilities [n_masked, patch_out_dim] with uniform prototype distribution

stable_pretraining.losses.mae(target, pred, mask, norm_pix_loss=False)[source]#

Compute masked autoencoder loss.

Parameters:
  • target – [N, L, p*p*3] target images

  • pred – [N, L, p*p*3] predicted images

  • mask – [N, L], 0 is keep, 1 is remove

  • norm_pix_loss – whether to normalize pixels

Returns:

mean loss value

Return type:

loss

stable_pretraining.losses.off_diagonal(x)[source]#

Return a flattened view of the off-diagonal elements of a square matrix.

stable_pretraining.losses.sinkhorn_knopp(teacher_output: Tensor, teacher_temp: float, num_samples: int | Tensor, n_iterations: int = 3) Tensor[source]#

Sinkhorn-Knopp algorithm for optimal transport normalization.

This is an alternative to simple centering used in DINO/iBOT losses. It performs optimal transport to assign samples to prototypes, ensuring a more uniform distribution across prototypes.

Reference: DINOv3 implementation facebookresearch/dinov3

Parameters:
  • teacher_output – Teacher predictions [batch, prototypes] or [n_samples, prototypes]

  • teacher_temp – Temperature for softmax

  • num_samples – Number of samples to assign. Can be: - int: Fixed batch size (e.g., batch_size * world_size for DINO) - torch.Tensor: Variable count (e.g., n_masked_patches for iBOT)

  • n_iterations – Number of Sinkhorn iterations (default: 3)

Returns:

Normalized probabilities [batch, prototypes] summing to 1 over prototypes

Examples

```python
# DINO CLS token loss (fixed batch size)
Q = sinkhorn_knopp(teacher_cls_output, teacher_temp=0.04,
                   num_samples=batch_size * world_size)

# iBOT patch loss (variable number of masked patches)
Q = sinkhorn_knopp(teacher_patch_output, teacher_temp=0.04,
                   num_samples=n_masked_patches_tensor)
```