Source code for stable_pretraining.losses.multimodal
"""Multimodal SSL losses.

This module contains losses for multimodal self-supervised learning,
in particular image-text contrastive objectives such as CLIP.
"""
import torch
from typing import Optional
from .joint_embedding import InfoNCELoss
[docs]
class CLIPLoss(InfoNCELoss):
    """CLIP loss (symmetric bidirectional InfoNCE).

    As used in CLIP :cite:`radford2021learning`.

    Computes symmetric cross-entropy over image-text and text-image logits.

    Args:
        temperature (float, optional): Softmax temperature. Default is 0.07.
            (If you use a learnable logit_scale in your model, pass it to
            forward(...) and this temperature will be ignored.)
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__(temperature=temperature)

    def forward(
        self,
        feats_i: torch.Tensor,
        feats_j: torch.Tensor,
        logit_scale: Optional[torch.Tensor | float] = None,
    ) -> torch.Tensor:
        """Compute the symmetric CLIP loss for a batch of paired embeddings.

        Row ``k`` of ``feats_i`` is paired with row ``k`` of ``feats_j``: the
        classification targets are the diagonal indices ``0..N-1`` in both
        directions (i -> j and j -> i).

        Args:
            feats_i: Embeddings of the first modality (e.g. images). Dimension
                0 indexes the batch.
            feats_j: Embeddings of the second modality (e.g. texts), paired
                row-by-row with ``feats_i``. Must have the same size along
                dimension 0 as ``feats_i``.
            logit_scale: Optional scale forwarded unchanged to
                ``InfoNCELoss._compute`` for both directions.
                NOTE(review): whether it is expected in log space or already
                exponentiated is defined by ``_compute`` -- confirm there.

        Returns:
            The average of the i -> j and j -> i losses returned by
            ``_compute``.

        Raises:
            ValueError: If ``feats_i`` and ``feats_j`` have different batch
                sizes. Without this check the mismatch surfaces as an
                out-of-range target index or a cross-entropy shape error
                inside ``_compute``.
        """
        batch_size = feats_i.size(0)
        # The same diagonal targets are reused for both directions, so the two
        # batches must be paired one-to-one; fail early with a clear message.
        if feats_j.size(0) != batch_size:
            raise ValueError(
                "CLIPLoss expects paired inputs with the same batch size, got "
                f"feats_i.size(0)={batch_size} and "
                f"feats_j.size(0)={feats_j.size(0)}"
            )

        # for CLIP, targets are always the diagonal
        targets = torch.arange(batch_size, device=feats_i.device)

        # calculate loss in both directions
        loss_i = self._compute(
            anchors=feats_i,
            candidates=feats_j,
            targets=targets,
            logit_scale=logit_scale,
        )
        loss_j = self._compute(
            anchors=feats_j,
            candidates=feats_i,
            targets=targets,
            logit_scale=logit_scale,
        )
        return 0.5 * (loss_i + loss_j)