stable_pretraining.losses package#
Submodules#
stable_pretraining.losses.dino module#
DINO self-distillation losses.
This module contains losses for DINO-style self-distillation, including:
- DINOv1Loss: CLS token distillation
- iBOTPatchLoss: masked patch prediction
Reference: DINOv2/v3 papers and facebookresearch/dinov3
- class stable_pretraining.losses.dino.DINOv1Loss(temperature_student: float = 0.1, center_momentum: float = 0.9)[source]#
Bases:
Module
DINOv1 loss for self-distillation with cross-entropy [Caron et al., 2021].
This loss computes cross-entropy between teacher and student logits after applying temperature scaling and normalization. The teacher uses either classical centering or Sinkhorn-Knopp normalization to prevent mode collapse.
- Usage:
```python
dino_loss = DINOv1Loss()

# Get logits from the prototype layer
student_logits = prototype_layer(student_embeddings)  # [n_views, B, out_dim]
teacher_logits = prototype_layer(teacher_embeddings)  # [n_views, B, out_dim]

# Approach 1: Classical centering (recommended, faster)
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Queue async center update

# Approach 2: Sinkhorn-Knopp (more principled, slower, no centering needed)
n_views, batch_size, _ = teacher_logits.shape
num_samples = n_views * batch_size  # Total samples across views
teacher_probs = dino_loss.sinkhorn_knopp_teacher(
    teacher_logits, teacher_temp=0.04, num_samples=num_samples
)
loss = dino_loss(student_logits, teacher_probs)
# No update_center() needed for Sinkhorn-Knopp!
```
- Parameters:
temperature_student (float) – Temperature for the student softmax. Default is 0.1.
center_momentum (float) – EMA momentum for teacher centering. Default is 0.9.
- apply_center_update()[source]#
Apply the queued center update with EMA.
FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.
Waits for async all-reduce to complete and updates self.center with EMA. Automatically called by softmax_center_teacher() if update_centers=True.
- forward(student_logits: Tensor, teacher_probs: Tensor) Tensor[source]#
Compute DINO cross-entropy loss.
This is a pure loss computation with no side effects (no centering, no updates). Teacher probabilities should be pre-processed with softmax_center_teacher() or sinkhorn_knopp_teacher(). Center updates should be done separately with update_center().
- Parameters:
student_logits – Student logits [n_views, batch_size, out_dim]
teacher_probs – Teacher probabilities (already normalized) [n_views, batch_size, out_dim]
- Returns:
Scalar DINO loss value (cross-entropy averaged over view pairs, excluding diagonal)
- Shape:
student_logits: (S, B, K) where S = student views, B = batch size, K = out_dim
teacher_probs: (T, B, K) where T = teacher views
output: scalar
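To make the pairing rule concrete, here is a rough, hedged re-implementation of the documented behavior in plain PyTorch (illustrative only; the library's exact reduction may differ):

```python
import torch
import torch.nn.functional as F

def dino_cross_entropy_sketch(student_logits, teacher_probs, student_temp=0.1):
    """Cross-entropy over (teacher view, student view) pairs, skipping the
    diagonal (same-view) pairs, as described in the docstring above."""
    S, B, K = student_logits.shape
    T = teacher_probs.shape[0]
    log_p = F.log_softmax(student_logits / student_temp, dim=-1)  # (S, B, K)
    total, n_pairs = 0.0, 0
    for t in range(T):
        for s in range(S):
            if t == s:  # exclude the diagonal (same-view) pair
                continue
            total = total - (teacher_probs[t] * log_p[s]).sum(-1).mean()
            n_pairs += 1
    return total / n_pairs
```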
- sinkhorn_knopp_teacher(teacher_logits, teacher_temp, num_samples=None, n_iterations=3)[source]#
Apply Sinkhorn-Knopp optimal transport normalization to teacher logits.
FOR SINKHORN-KNOPP APPROACH ONLY. DOES NOT USE CENTER.
This method applies Sinkhorn-Knopp normalization to enforce an exactly uniform distribution across prototypes without using centering. It is more principled than centering but more expensive, and is used in SwAV and DINOv3 for better theoretical guarantees.
Note: When using Sinkhorn-Knopp, you do NOT need to call update_center() or apply_center_update() since centering is not used.
- Parameters:
teacher_logits – Teacher logits [*, out_dim]. Can be any shape as long as last dim is out_dim. Common shapes: [batch, out_dim] or [n_views, batch, out_dim]
teacher_temp – Temperature for softmax
num_samples – Total number of samples across all GPUs (int or tensor). If None, inferred from shape assuming [batch, out_dim] format. For multi-view [n_views, batch, out_dim], pass n_views * batch explicitly.
n_iterations – Number of Sinkhorn iterations (default: 3)
- Returns:
Teacher probabilities [same shape as input] with uniform prototype distribution
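A minimal sketch of the multi-view case, where num_samples must be passed explicitly (all sizes are illustrative):

```python
import torch
from stable_pretraining.losses import DINOv1Loss

dino_loss = DINOv1Loss()
teacher_logits = torch.randn(2, 16, 4096)  # [n_views, batch, out_dim]

n_views, batch_size, _ = teacher_logits.shape
teacher_probs = dino_loss.sinkhorn_knopp_teacher(
    teacher_logits,
    teacher_temp=0.04,
    num_samples=n_views * batch_size,  # explicit for multi-view input
)
# Same shape as teacher_logits; no update_center() call is needed afterwards.
```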
- softmax_center_teacher(teacher_logits, teacher_temp, update_centers=True)[source]#
Apply classical centering and sharpening to teacher logits.
FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.
This method subtracts the center (EMA of batch means) from teacher logits before applying softmax. This prevents mode collapse by ensuring balanced prototype usage.
- update_center(teacher_output)[source]#
Queue async center update from teacher logits.
FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.
Starts an asynchronous all-reduce for distributed training. The update is applied later when softmax_center_teacher() is called with update_centers=True. This allows the all-reduce to overlap with backward pass for efficiency.
- Typical usage:
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, temp)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Start async update
# … backward pass happens here, overlapping with all-reduce …
# Next iteration: softmax_center_teacher() will call apply_center_update()
- Parameters:
teacher_output – Teacher logits [n_views, batch_size, out_dim]
- class stable_pretraining.losses.dino.DINOv2Loss(dino_loss_weight: float = 1.0, ibot_loss_weight: float = 1.0, temperature_student: float = 0.1, center_momentum: float = 0.9, student_temp: float = 0.1)[source]#
Bases:
Module
DINOv2 loss combining CLS token and masked patch losses.
DINOv2 combines two losses:
- DINOv1Loss: CLS token distillation (global views), using Sinkhorn-Knopp
- iBOTPatchLoss: masked patch prediction, using Sinkhorn-Knopp
Both losses use Sinkhorn-Knopp normalization in DINOv2.
- Parameters:
dino_loss_weight (float) – Weight for CLS token loss. Default is 1.0.
ibot_loss_weight (float) – Weight for iBOT patch loss. Default is 1.0.
temperature_student (float) – Temperature for student softmax in DINO. Default is 0.1.
center_momentum (float) – EMA momentum for DINO centering (not used by iBOT). Default is 0.9.
student_temp (float) – Temperature for student softmax in iBOT. Default is 0.1.
- forward(student_cls_logits: Tensor, teacher_cls_probs: Tensor, student_patch_logits: Tensor = None, teacher_patch_probs: Tensor = None) Tensor[source]#
Compute combined DINOv2 loss.
- Parameters:
student_cls_logits – Student CLS logits [n_views, batch, out_dim]
teacher_cls_probs – Teacher CLS probs [n_views, batch, out_dim]
student_patch_logits – Student patch logits [n_masked_total, patch_out_dim] or None
teacher_patch_probs – Teacher patch probs [n_masked_total, patch_out_dim] or None
- Returns:
Combined weighted loss
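A hedged end-to-end sketch of how the pieces above could be wired together. The standalone DINOv1Loss and iBOTPatchLoss helpers are used here to produce teacher probabilities; DINOv2Loss may expose equivalent helpers of its own, and all shapes are placeholders:

```python
import torch
from stable_pretraining.losses import DINOv1Loss, DINOv2Loss, iBOTPatchLoss

dino_v2 = DINOv2Loss(dino_loss_weight=1.0, ibot_loss_weight=1.0)
cls_helper = DINOv1Loss()       # used only for sinkhorn_knopp_teacher()
patch_helper = iBOTPatchLoss()  # idem, for patch logits

# Placeholder logits: 2 global views, batch 16, 4096 prototypes, 37 masked patches
student_cls = torch.randn(2, 16, 4096)
teacher_cls = torch.randn(2, 16, 4096)
student_patch = torch.randn(37, 4096)
teacher_patch = torch.randn(37, 4096)

teacher_cls_probs = cls_helper.sinkhorn_knopp_teacher(
    teacher_cls, teacher_temp=0.04, num_samples=2 * 16
)
teacher_patch_probs = patch_helper.sinkhorn_knopp_teacher(
    teacher_patch, teacher_temp=0.04
)
loss = dino_v2(student_cls, teacher_cls_probs, student_patch, teacher_patch_probs)
```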
- stable_pretraining.losses.dino.cross_entropy_loss(t, s, temp)[source]#
Cross-entropy loss function for iBOT.
Computes per-sample cross-entropy: -Σ t[i] * log_softmax(s[i]/temp)
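Written out in PyTorch, the documented formula is equivalent to the following sketch (the library function may differ in signature details or reduction):

```python
import torch
import torch.nn.functional as F

def cross_entropy_sketch(t, s, temp):
    # Per-sample cross-entropy: -sum_k t[i, k] * log_softmax(s[i] / temp)[k]
    return -(t * F.log_softmax(s / temp, dim=-1)).sum(dim=-1)

t = F.softmax(torch.randn(4, 10), dim=-1)  # teacher probabilities
s = torch.randn(4, 10)                     # student logits
per_sample = cross_entropy_sketch(t, s, temp=0.1)  # shape [4]
```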
- class stable_pretraining.losses.dino.iBOTPatchLoss(student_temp: float = 0.1)[source]#
Bases:
Module
iBOT patch-level loss for masked patch prediction.
This loss computes cross-entropy between teacher and student patch predictions for masked patches only. Uses Sinkhorn-Knopp normalization exclusively (as in DINOv2/v3) to prevent mode collapse.
- Parameters:
student_temp (float) – Temperature for student softmax. Default is 0.1.
- forward(student_patch_logits, teacher_patch_probs)[source]#
Compute iBOT cross-entropy loss for masked patches.
- Parameters:
student_patch_logits – Student patch logits [n_masked_total, patch_out_dim]
teacher_patch_probs – Teacher probabilities [n_masked_total, patch_out_dim]
- Returns:
Scalar iBOT loss value
- sinkhorn_knopp_teacher(teacher_patch_tokens, teacher_temp, num_samples=None, n_iterations=3)[source]#
Apply Sinkhorn-Knopp optimal transport normalization to teacher patch logits.
This method applies optimal transport to enforce exact uniform distribution across prototypes. Used exclusively in DINOv2/v3 for iBOT patch loss.
- Parameters:
teacher_patch_tokens – Teacher patch logits [n_masked, patch_out_dim]
teacher_temp – Temperature for softmax
num_samples – Total number of masked patches across all GPUs (int or tensor). If None, inferred from shape.
n_iterations – Number of Sinkhorn iterations (default: 3)
- Returns:
Teacher probabilities [n_masked, patch_out_dim] with uniform prototype distribution
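A minimal usage sketch, assuming the masked-patch logits have already been gathered (sizes are illustrative):

```python
import torch
from stable_pretraining.losses import iBOTPatchLoss

ibot_loss = iBOTPatchLoss(student_temp=0.1)

n_masked, patch_out_dim = 37, 4096
student_patch_logits = torch.randn(n_masked, patch_out_dim)
teacher_patch_logits = torch.randn(n_masked, patch_out_dim)

teacher_patch_probs = ibot_loss.sinkhorn_knopp_teacher(
    teacher_patch_logits, teacher_temp=0.04
)
loss = ibot_loss(student_patch_logits, teacher_patch_probs)  # scalar
```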
stable_pretraining.losses.joint_embedding module#
Joint embedding SSL losses.
This module contains joint embedding methods that learn to embed different views of the same image close together in representation space. Includes both contrastive (NTXentLoss) and non-contrastive (BYOL, VICReg, Barlow Twins) methods.
- class stable_pretraining.losses.joint_embedding.BYOLLoss(*args, **kwargs)[source]#
Bases:
Module
Normalized MSE objective used in BYOL [Grill et al., 2020].
Computes the mean squared error between L2-normalized online predictions and L2-normalized target projections.
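Since BYOLLoss.forward is not documented in detail here, the objective itself can be sketched directly in PyTorch as a reference (one common formulation of the normalized MSE; not necessarily the exact signature used by BYOLLoss):

```python
import torch
import torch.nn.functional as F

def normalized_mse(online_pred, target_proj):
    # L2-normalize both sides, then squared error summed over the feature dim
    p = F.normalize(online_pred, dim=-1)
    z = F.normalize(target_proj, dim=-1)
    return ((p - z) ** 2).sum(dim=-1).mean()

loss = normalized_mse(torch.randn(16, 256), torch.randn(16, 256))
```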
- class stable_pretraining.losses.joint_embedding.BarlowTwinsLoss(lambd: float = 0.005)[source]#
Bases:
Module
SSL objective used in Barlow Twins [Zbontar et al., 2021].
- Parameters:
lambd (float, optional) – The weight of the off-diagonal terms in the loss. Default is 5e-3.
- forward(z_i, z_j)[source]#
Compute the loss of the Barlow Twins model.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
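A minimal usage sketch (batch and embedding sizes are illustrative):

```python
import torch
from stable_pretraining.losses import BarlowTwinsLoss

criterion = BarlowTwinsLoss(lambd=5e-3)
z_i = torch.randn(256, 8192)  # projector output for view 1
z_j = torch.randn(256, 8192)  # projector output for view 2
loss = criterion(z_i, z_j)
```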
- class stable_pretraining.losses.joint_embedding.InfoNCELoss(temperature: float = 0.07)[source]#
Bases:
Module
InfoNCE contrastive loss (one-directional).
This module computes the cross-entropy loss between anchor embeddings and a set of candidate embeddings, given the ground-truth targets. It forms the core mathematical operation for losses like those in CLIP and SimCLR.
- Parameters:
temperature (float, optional) – The temperature scaling factor. Default is 0.07.
- forward(anchors: Tensor, candidates: Tensor, targets: Tensor, mask: Tensor | None = None, logit_scale: Tensor | float | None = None) Tensor[source]#
Computes the contrastive loss.
- Parameters:
anchors (torch.Tensor) – The primary set of embeddings (queries) of shape [N, D].
candidates (torch.Tensor) – The set of embeddings to contrast against (keys) of shape [M, D].
targets (torch.Tensor) – A 1D tensor of ground-truth indices of shape [N], where targets[i] is the index of the positive candidate for anchors[i].
mask (torch.Tensor, optional) – A boolean mask of shape [N, M] to exclude certain anchor-candidate pairs from the loss calculation. Values set to True will be ignored.
logit_scale (torch.Tensor | float, optional) – The temperature scaling factor. Default is self.temperature.
- Returns:
A scalar loss value.
- Return type:
torch.Tensor
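A minimal usage sketch for the self-matching case, where candidate i is the positive for anchor i (all sizes are illustrative):

```python
import torch
import torch.nn.functional as F
from stable_pretraining.losses.joint_embedding import InfoNCELoss

criterion = InfoNCELoss(temperature=0.07)
anchors = F.normalize(torch.randn(8, 128), dim=-1)     # queries [N, D]
candidates = F.normalize(torch.randn(8, 128), dim=-1)  # keys    [M, D], here M == N
targets = torch.arange(8)                              # positive index per anchor
loss = criterion(anchors, candidates, targets)
```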
- class stable_pretraining.losses.joint_embedding.NTXEntLoss(temperature: float = 0.5)[source]#
Bases:
InfoNCELoss
Normalized temperature-scaled cross-entropy loss.
Introduced in the SimCLR paper [Chen et al., 2020]. Also used in MoCo [He et al., 2020].
- Parameters:
temperature (float, optional) – The temperature scaling factor. Default is 0.5.
- forward(z_i: Tensor, z_j: Tensor) Tensor[source]#
Compute the NT-Xent loss.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed contrastive loss.
- Return type:
torch.Tensor
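A minimal usage sketch (sizes are illustrative):

```python
import torch
from stable_pretraining.losses import NTXEntLoss

criterion = NTXEntLoss(temperature=0.5)
z_i = torch.randn(128, 256)  # embeddings of the first augmented view
z_j = torch.randn(128, 256)  # embeddings of the second augmented view
loss = criterion(z_i, z_j)
```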
- class stable_pretraining.losses.joint_embedding.SwAVLoss(temperature: float = 0.1, sinkhorn_iterations: int = 3, epsilon: float = 0.05)[source]#
Bases:
Module
Computes the SwAV loss, optionally using a feature queue.
This loss function contains the core components of the SwAV algorithm, including the Sinkhorn-Knopp algorithm for online clustering and the swapped-prediction contrastive task.
- Parameters:
temperature (float, optional) – The temperature scaling factor for the softmax in the swapped prediction task. Default is 0.1.
sinkhorn_iterations (int, optional) – The number of iterations for the Sinkhorn-Knopp algorithm. Default is 3.
epsilon (float, optional) – A small value for numerical stability in the Sinkhorn-Knopp algorithm. Default is 0.05.
Note
Introduced in the SwAV paper [Caron et al., 2020].
- forward(proj1, proj2, prototypes, queue_feats=None)[source]#
Compute the SwAV loss.
- Parameters:
proj1 (torch.Tensor) – Raw projections of the first view.
proj2 (torch.Tensor) – Raw projections of the second view.
prototypes (torch.nn.Module) – The prototype vectors.
queue_feats (torch.Tensor, optional) – Raw features from the queue.
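A minimal usage sketch; the prototype module here is a bias-free linear layer, which is an assumption about how prototypes are typically represented rather than a requirement stated above:

```python
import torch
import torch.nn as nn
from stable_pretraining.losses.joint_embedding import SwAVLoss

criterion = SwAVLoss(temperature=0.1, sinkhorn_iterations=3, epsilon=0.05)
prototypes = nn.Linear(128, 300, bias=False)  # assumed prototype layer (D -> K)
proj1 = torch.randn(64, 128)  # raw projections of view 1
proj2 = torch.randn(64, 128)  # raw projections of view 2
loss = criterion(proj1, proj2, prototypes)  # optionally: queue_feats=...
```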
- class stable_pretraining.losses.joint_embedding.VICRegLoss(sim_coeff: float = 25, std_coeff: float = 25, cov_coeff: float = 1, epsilon: float = 0.0001)[source]#
Bases:
Module
SSL objective used in VICReg [Bardes et al., 2021].
- Parameters:
sim_coeff (float, optional) – The weight of the similarity loss (attractive term). Default is 25.
std_coeff (float, optional) – The weight of the standard deviation loss. Default is 25.
cov_coeff (float, optional) – The weight of the covariance loss. Default is 1.
epsilon (float, optional) – Small value to avoid division by zero. Default is 1e-4.
- forward(z_i, z_j)[source]#
Compute the loss of the VICReg model.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
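A minimal usage sketch (sizes are illustrative):

```python
import torch
from stable_pretraining.losses import VICRegLoss

criterion = VICRegLoss(sim_coeff=25, std_coeff=25, cov_coeff=1)
z_i = torch.randn(256, 8192)  # expander output for view 1
z_j = torch.randn(256, 8192)  # expander output for view 2
loss = criterion(z_i, z_j)
```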
stable_pretraining.losses.multimodal module#
Multimodal SSL losses.
This module contains losses for multimodal self-supervised learning, particularly for image-text contrastive learning like CLIP.
- class stable_pretraining.losses.multimodal.CLIPLoss(temperature: float = 0.07)[source]#
Bases:
InfoNCELoss
CLIP loss (symmetric bidirectional InfoNCE).
As used in CLIP [Radford et al., 2021]. Computes symmetric cross-entropy over image-text and text-image logits.
- Parameters:
temperature (float, optional) – Softmax temperature. Default is 0.07. (If you use a learnable logit_scale in your model, pass it to forward(…) and this temperature will be ignored.)
- forward(feats_i: Tensor, feats_j: Tensor, logit_scale: Tensor | float | None = None) Tensor[source]#
Compute the symmetric CLIP contrastive loss.
- Parameters:
feats_i (torch.Tensor) – First set of embeddings (e.g., image features) of shape [N, D].
feats_j (torch.Tensor) – Second set of embeddings (e.g., text features) of shape [N, D], paired row-wise with feats_i.
logit_scale (torch.Tensor | float, optional) – The temperature scaling factor. Default is self.temperature.
- Returns:
A scalar loss value.
- Return type:
torch.Tensor
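A minimal usage sketch with paired, L2-normalized embeddings (the sizes and the normalization step are illustrative assumptions):

```python
import torch
import torch.nn.functional as F
from stable_pretraining.losses import CLIPLoss

criterion = CLIPLoss(temperature=0.07)
image_feats = F.normalize(torch.randn(32, 512), dim=-1)  # image embeddings [N, D]
text_feats = F.normalize(torch.randn(32, 512), dim=-1)   # paired text embeddings [N, D]
loss = criterion(image_feats, text_feats)
# With a learnable temperature, pass it explicitly and self.temperature is ignored:
# loss = criterion(image_feats, text_feats, logit_scale=logit_scale)
```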
stable_pretraining.losses.reconstruction module#
Reconstruction-based SSL losses.
This module contains reconstruction-based self-supervised learning losses such as Masked Autoencoder (MAE).
- stable_pretraining.losses.reconstruction.mae(target, pred, mask, norm_pix_loss=False)[source]#
Compute masked autoencoder loss.
- Parameters:
target – [N, L, p*p*3] target images
pred – [N, L, p*p*3] predicted images
mask – [N, L], 0 is keep, 1 is remove
norm_pix_loss – whether to normalize pixels
- Returns:
mean loss value
- Return type:
torch.Tensor
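A minimal usage sketch with patchified inputs (sizes and mask ratio are illustrative):

```python
import torch
from stable_pretraining.losses import mae

N, L, p = 8, 196, 16                      # batch, patches per image, patch size
target = torch.randn(N, L, p * p * 3)     # patchified target images
pred = torch.randn(N, L, p * p * 3)       # decoder predictions
mask = (torch.rand(N, L) > 0.25).float()  # 1 = removed (reconstructed), 0 = kept
loss = mae(target, pred, mask, norm_pix_loss=False)
```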
stable_pretraining.losses.utils module#
Utilities for SSL losses.
This module provides helper functions and utilities used by various SSL losses, such as Sinkhorn-Knopp optimal transport algorithm for DINO and iBOT.
- class stable_pretraining.losses.utils.NegativeCosineSimilarity(*args, **kwargs)[source]#
Bases:
Module
Negative cosine similarity objective.
This objective is used for instance in BYOL [Grill et al., 2020] or SimSiam [Chen and He, 2021].
- forward(z_i, z_j)[source]#
Compute the negative cosine similarity between the two views.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
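A minimal usage sketch (sizes are illustrative):

```python
import torch
from stable_pretraining.losses import NegativeCosineSimilarity

criterion = NegativeCosineSimilarity()
z_i = torch.randn(64, 256)  # e.g. predictor output for view 1
z_j = torch.randn(64, 256)  # e.g. projector output for view 2
loss = criterion(z_i, z_j)
```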
- stable_pretraining.losses.utils.off_diagonal(x)[source]#
Return a flattened view of the off-diagonal elements of a square matrix.
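A small sketch of what the helper returns (this is the term Barlow Twins penalizes on the cross-correlation matrix):

```python
import torch
from stable_pretraining.losses import off_diagonal

c = torch.arange(9.0).reshape(3, 3)  # a square (e.g. cross-correlation) matrix
off = off_diagonal(c)                # flattened off-diagonal entries, shape [3*3 - 3]
```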
- stable_pretraining.losses.utils.sinkhorn_knopp(teacher_output: Tensor, teacher_temp: float, num_samples: int | Tensor, n_iterations: int = 3) Tensor[source]#
Sinkhorn-Knopp algorithm for optimal transport normalization.
This is an alternative to simple centering used in DINO/iBOT losses. It performs optimal transport to assign samples to prototypes, ensuring a more uniform distribution across prototypes.
Reference: DINOv3 implementation facebookresearch/dinov3
- Parameters:
teacher_output – Teacher predictions [batch, prototypes] or [n_samples, prototypes]
teacher_temp – Temperature for softmax
num_samples – Number of samples to assign. Can be:
- int: fixed batch size (e.g., batch_size * world_size for DINO)
- torch.Tensor: variable count (e.g., n_masked_patches for iBOT)
n_iterations – Number of Sinkhorn iterations (default: 3)
- Returns:
Normalized probabilities [batch, prototypes] summing to 1 over prototypes
Examples
# DINO CLS token loss (fixed batch size)
Q = sinkhorn_knopp(teacher_cls_output, teacher_temp=0.04,
                   num_samples=batch_size * world_size)

# iBOT patch loss (variable number of masked patches)
Q = sinkhorn_knopp(teacher_patch_output, teacher_temp=0.04,
                   num_samples=n_masked_patches_tensor)
Module contents#
SSL Losses.
This module provides various self-supervised learning loss functions organized by category:
- DINO losses: self-distillation methods (DINOv1Loss, DINOv2Loss, iBOTPatchLoss)
- Joint embedding losses: contrastive and non-contrastive methods (BYOL, VICReg, Barlow Twins, SimCLR)
- Reconstruction losses: masked prediction methods (MAE)
- Utilities: helper functions (sinkhorn_knopp, off_diagonal, NegativeCosineSimilarity)
- class stable_pretraining.losses.BYOLLoss(*args, **kwargs)[source]#
Bases:
Module
Normalized MSE objective used in BYOL [Grill et al., 2020].
Computes the mean squared error between L2-normalized online predictions and L2-normalized target projections.
- class stable_pretraining.losses.BarlowTwinsLoss(lambd: float = 0.005)[source]#
Bases:
Module
SSL objective used in Barlow Twins [Zbontar et al., 2021].
- Parameters:
lambd (float, optional) – The weight of the off-diagonal terms in the loss. Default is 5e-3.
- forward(z_i, z_j)[source]#
Compute the loss of the Barlow Twins model.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- class stable_pretraining.losses.CLIPLoss(temperature: float = 0.07)[source]#
Bases:
InfoNCELoss
CLIP loss (symmetric bidirectional InfoNCE).
As used in CLIP [Radford et al., 2021]. Computes symmetric cross-entropy over image-text and text-image logits.
- Parameters:
temperature (float, optional) – Softmax temperature. Default is 0.07. (If you use a learnable logit_scale in your model, pass it to forward(…) and this temperature will be ignored.)
- forward(feats_i: Tensor, feats_j: Tensor, logit_scale: Tensor | float | None = None) Tensor[source]#
Compute the symmetric CLIP contrastive loss.
- Parameters:
feats_i (torch.Tensor) – First set of embeddings (e.g., image features) of shape [N, D].
feats_j (torch.Tensor) – Second set of embeddings (e.g., text features) of shape [N, D], paired row-wise with feats_i.
logit_scale (torch.Tensor | float, optional) – The temperature scaling factor. Default is self.temperature.
- Returns:
A scalar loss value.
- Return type:
torch.Tensor
- class stable_pretraining.losses.DINOv1Loss(temperature_student: float = 0.1, center_momentum: float = 0.9)[source]#
Bases:
Module
DINOv1 loss for self-distillation with cross-entropy [Caron et al., 2021].
This loss computes cross-entropy between teacher and student logits after applying temperature scaling and normalization. The teacher uses either classical centering or Sinkhorn-Knopp normalization to prevent mode collapse.
- Usage:
```python
dino_loss = DINOv1Loss()

# Get logits from the prototype layer
student_logits = prototype_layer(student_embeddings)  # [n_views, B, out_dim]
teacher_logits = prototype_layer(teacher_embeddings)  # [n_views, B, out_dim]

# Approach 1: Classical centering (recommended, faster)
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Queue async center update

# Approach 2: Sinkhorn-Knopp (more principled, slower, no centering needed)
n_views, batch_size, _ = teacher_logits.shape
num_samples = n_views * batch_size  # Total samples across views
teacher_probs = dino_loss.sinkhorn_knopp_teacher(
    teacher_logits, teacher_temp=0.04, num_samples=num_samples
)
loss = dino_loss(student_logits, teacher_probs)
# No update_center() needed for Sinkhorn-Knopp!
```
- Parameters:
temperature_student (float) – Temperature for the student softmax. Default is 0.1.
center_momentum (float) – EMA momentum for teacher centering. Default is 0.9.
- apply_center_update()[source]#
Apply the queued center update with EMA.
FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.
Waits for async all-reduce to complete and updates self.center with EMA. Automatically called by softmax_center_teacher() if update_centers=True.
- forward(student_logits: Tensor, teacher_probs: Tensor) Tensor[source]#
Compute DINO cross-entropy loss.
This is a pure loss computation with no side effects (no centering, no updates). Teacher probabilities should be pre-processed with softmax_center_teacher() or sinkhorn_knopp_teacher(). Center updates should be done separately with update_center().
- Parameters:
student_logits – Student logits [n_views, batch_size, out_dim]
teacher_probs – Teacher probabilities (already normalized) [n_views, batch_size, out_dim]
- Returns:
Scalar DINO loss value (cross-entropy averaged over view pairs, excluding diagonal)
- Shape:
student_logits: (S, B, K) where S = student views, B = batch size, K = out_dim
teacher_probs: (T, B, K) where T = teacher views
output: scalar
- sinkhorn_knopp_teacher(teacher_logits, teacher_temp, num_samples=None, n_iterations=3)[source]#
Apply Sinkhorn-Knopp optimal transport normalization to teacher logits.
FOR SINKHORN-KNOPP APPROACH ONLY. DOES NOT USE CENTER.
This method applies Sinkhorn-Knopp normalization to enforce an exactly uniform distribution across prototypes without using centering. It is more principled than centering but more expensive, and is used in SwAV and DINOv3 for better theoretical guarantees.
Note: When using Sinkhorn-Knopp, you do NOT need to call update_center() or apply_center_update() since centering is not used.
- Parameters:
teacher_logits – Teacher logits [*, out_dim]. Can be any shape as long as last dim is out_dim. Common shapes: [batch, out_dim] or [n_views, batch, out_dim]
teacher_temp – Temperature for softmax
num_samples – Total number of samples across all GPUs (int or tensor). If None, inferred from shape assuming [batch, out_dim] format. For multi-view [n_views, batch, out_dim], pass n_views * batch explicitly.
n_iterations – Number of Sinkhorn iterations (default: 3)
- Returns:
Teacher probabilities [same shape as input] with uniform prototype distribution
- softmax_center_teacher(teacher_logits, teacher_temp, update_centers=True)[source]#
Apply classical centering and sharpening to teacher logits.
FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.
This method subtracts the center (EMA of batch means) from teacher logits before applying softmax. This prevents mode collapse by ensuring balanced prototype usage.
- update_center(teacher_output)[source]#
Queue async center update from teacher logits.
FOR CLASSICAL CENTERING APPROACH ONLY. NOT NEEDED FOR SINKHORN-KNOPP.
Starts an asynchronous all-reduce for distributed training. The update is applied later when softmax_center_teacher() is called with update_centers=True. This allows the all-reduce to overlap with backward pass for efficiency.
- Typical usage:
teacher_probs = dino_loss.softmax_center_teacher(teacher_logits, temp)
loss = dino_loss(student_logits, teacher_probs)
dino_loss.update_center(teacher_logits)  # Start async update
# … backward pass happens here, overlapping with all-reduce …
# Next iteration: softmax_center_teacher() will call apply_center_update()
- Parameters:
teacher_output – Teacher logits [n_views, batch_size, out_dim]
- class stable_pretraining.losses.DINOv2Loss(dino_loss_weight: float = 1.0, ibot_loss_weight: float = 1.0, temperature_student: float = 0.1, center_momentum: float = 0.9, student_temp: float = 0.1)[source]#
Bases:
Module
DINOv2 loss combining CLS token and masked patch losses.
DINOv2 combines two losses:
- DINOv1Loss: CLS token distillation (global views), using Sinkhorn-Knopp
- iBOTPatchLoss: masked patch prediction, using Sinkhorn-Knopp
Both losses use Sinkhorn-Knopp normalization in DINOv2.
- Parameters:
dino_loss_weight (float) – Weight for CLS token loss. Default is 1.0.
ibot_loss_weight (float) – Weight for iBOT patch loss. Default is 1.0.
temperature_student (float) – Temperature for student softmax in DINO. Default is 0.1.
center_momentum (float) – EMA momentum for DINO centering (not used by iBOT). Default is 0.9.
student_temp (float) – Temperature for student softmax in iBOT. Default is 0.1.
- forward(student_cls_logits: Tensor, teacher_cls_probs: Tensor, student_patch_logits: Tensor = None, teacher_patch_probs: Tensor = None) Tensor[source]#
Compute combined DINOv2 loss.
- Parameters:
student_cls_logits – Student CLS logits [n_views, batch, out_dim]
teacher_cls_probs – Teacher CLS probs [n_views, batch, out_dim]
student_patch_logits – Student patch logits [n_masked_total, patch_out_dim] or None
teacher_patch_probs – Teacher patch probs [n_masked_total, patch_out_dim] or None
- Returns:
Combined weighted loss
- class stable_pretraining.losses.NTXEntLoss(temperature: float = 0.5)[source]#
Bases:
InfoNCELoss
Normalized temperature-scaled cross-entropy loss.
Introduced in the SimCLR paper [Chen et al., 2020]. Also used in MoCo [He et al., 2020].
- Parameters:
temperature (float, optional) – The temperature scaling factor. Default is 0.5.
- forward(z_i: Tensor, z_j: Tensor) Tensor[source]#
Compute the NT-Xent loss.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed contrastive loss.
- Return type:
torch.Tensor
- class stable_pretraining.losses.NegativeCosineSimilarity(*args, **kwargs)[source]#
Bases:
Module
Negative cosine similarity objective.
This objective is used for instance in BYOL [Grill et al., 2020] or SimSiam [Chen and He, 2021].
- forward(z_i, z_j)[source]#
Compute the negative cosine similarity between the two views.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- class stable_pretraining.losses.VICRegLoss(sim_coeff: float = 25, std_coeff: float = 25, cov_coeff: float = 1, epsilon: float = 0.0001)[source]#
Bases:
Module
SSL objective used in VICReg [Bardes et al., 2021].
- Parameters:
sim_coeff (float, optional) – The weight of the similarity loss (attractive term). Default is 25.
std_coeff (float, optional) – The weight of the standard deviation loss. Default is 25.
cov_coeff (float, optional) – The weight of the covariance loss. Default is 1.
epsilon (float, optional) – Small value to avoid division by zero. Default is 1e-4.
- forward(z_i, z_j)[source]#
Compute the loss of the VICReg model.
- Parameters:
z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.
z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.
- Returns:
The computed loss.
- Return type:
torch.Tensor
- class stable_pretraining.losses.iBOTPatchLoss(student_temp: float = 0.1)[source]#
Bases:
Module
iBOT patch-level loss for masked patch prediction.
This loss computes cross-entropy between teacher and student patch predictions for masked patches only. Uses Sinkhorn-Knopp normalization exclusively (as in DINOv2/v3) to prevent mode collapse.
- Parameters:
student_temp (float) – Temperature for student softmax. Default is 0.1.
- forward(student_patch_logits, teacher_patch_probs)[source]#
Compute iBOT cross-entropy loss for masked patches.
- Parameters:
student_patch_logits – Student patch logits [n_masked_total, patch_out_dim]
teacher_patch_probs – Teacher probabilities [n_masked_total, patch_out_dim]
- Returns:
Scalar iBOT loss value
- sinkhorn_knopp_teacher(teacher_patch_tokens, teacher_temp, num_samples=None, n_iterations=3)[source]#
Apply Sinkhorn-Knopp optimal transport normalization to teacher patch logits.
This method applies optimal transport to enforce exact uniform distribution across prototypes. Used exclusively in DINOv2/v3 for iBOT patch loss.
- Parameters:
teacher_patch_tokens – Teacher patch logits [n_masked, patch_out_dim]
teacher_temp – Temperature for softmax
num_samples – Total number of masked patches across all GPUs (int or tensor). If None, inferred from shape.
n_iterations – Number of Sinkhorn iterations (default: 3)
- Returns:
Teacher probabilities [n_masked, patch_out_dim] with uniform prototype distribution
- stable_pretraining.losses.mae(target, pred, mask, norm_pix_loss=False)[source]#
Compute masked autoencoder loss.
- Parameters:
target – [N, L, p*p*3] target images
pred – [N, L, p*p*3] predicted images
mask – [N, L], 0 is keep, 1 is remove
norm_pix_loss – whether to normalize pixels
- Returns:
mean loss value
- Return type:
torch.Tensor
- stable_pretraining.losses.off_diagonal(x)[source]#
Return a flattened view of the off-diagonal elements of a square matrix.
- stable_pretraining.losses.sinkhorn_knopp(teacher_output: Tensor, teacher_temp: float, num_samples: int | Tensor, n_iterations: int = 3) Tensor[source]#
Sinkhorn-Knopp algorithm for optimal transport normalization.
This is an alternative to simple centering used in DINO/iBOT losses. It performs optimal transport to assign samples to prototypes, ensuring a more uniform distribution across prototypes.
Reference: DINOv3 implementation facebookresearch/dinov3
- Parameters:
teacher_output – Teacher predictions [batch, prototypes] or [n_samples, prototypes]
teacher_temp – Temperature for softmax
num_samples – Number of samples to assign. Can be:
- int: fixed batch size (e.g., batch_size * world_size for DINO)
- torch.Tensor: variable count (e.g., n_masked_patches for iBOT)
n_iterations – Number of Sinkhorn iterations (default: 3)
- Returns:
Normalized probabilities [batch, prototypes] summing to 1 over prototypes
Examples
# DINO CLS token loss (fixed batch size)
Q = sinkhorn_knopp(teacher_cls_output, teacher_temp=0.04,
                   num_samples=batch_size * world_size)

# iBOT patch loss (variable number of masked patches)
Q = sinkhorn_knopp(teacher_patch_output, teacher_temp=0.04,
                   num_samples=n_masked_patches_tensor)