# Source code for stable_ssl.joint_embedding.simclr

# -*- coding: utf-8 -*-
"""SimCLR model."""
#
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#         Randall Balestriero <randallbalestriero@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import torch
import torch.nn.functional as F

from .base import JEConfig, JETrainer


class SimCLR(JETrainer):
    """SimCLR model from [CKNH20]_.

    Reference
    ---------
    .. [CKNH20] Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020).
            A Simple Framework for Contrastive Learning of Visual Representations.
            In International Conference on Machine Learning (pp. 1597-1607).
            PMLR.
    """

    def compute_ssl_loss(self, h_i, h_j):
        """Compute the contrastive (NT-Xent) loss for SimCLR.

        The two views are concatenated into ``z`` of ``N = 2 * B`` rows,
        L2-normalised, and compared pairwise (dot product divided by
        ``self.config.model.temperature``). For every row ``i`` the positive is
        the other view of the same sample and the denominator runs over every
        other row (positive included), so the loss is
        ``mean_i(-s[i, pos(i)] + logsumexp_{k != i} s[i, k])``.

        Parameters
        ----------
        h_i : torch.Tensor
            Latent representation of the first augmented view of the batch
            (rows are samples; presumably shape (B, D) — confirm with caller).
        h_j : torch.Tensor
            Latent representation of the second augmented view of the batch.
            Must have the same shape as ``h_i``; row ``k`` of ``h_j`` is the
            positive partner of row ``k`` of ``h_i``.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the computed contrastive loss.

        Raises
        ------
        ValueError
            If ``h_i`` and ``h_j`` do not have the same shape (the positive
            pairs would otherwise be silently misaligned).
        """
        if h_i.shape != h_j.shape:
            raise ValueError(
                "h_i and h_j must have the same shape, got "
                f"{tuple(h_i.shape)} and {tuple(h_j.shape)}"
            )

        z = torch.cat([h_i, h_j], 0)
        # Only the local views are in `z` (nothing is gathered across
        # processes here), so N must be the local row count. Scaling it by
        # world_size made the diagonal offsets and the mask disagree with the
        # (N_local, N_local) similarity matrix whenever world_size > 1.
        N = z.size(0)
        half = N // 2

        features = F.normalize(z, dim=1)
        sim = torch.matmul(features, features.T) / self.config.model.temperature

        # Positive pairs sit on the +/- half off-diagonals: sim[k, k + half]
        # pairs view i with view j, sim[k + half, k] the reverse.
        sim_i_j = torch.diag(sim, half)
        sim_j_i = torch.diag(sim, -half)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0)  # shape (N)

        # Drop self-similarities; build the mask on sim's own device so the
        # boolean indexing never mixes devices.
        mask = torch.eye(N, dtype=torch.bool, device=sim.device)
        negative_samples = sim[~mask].reshape(N, -1)  # shape (N, N-1)

        attraction = -positive_samples.mean()
        repulsion = torch.logsumexp(negative_samples, dim=1).mean()
        return attraction + repulsion
@dataclass
class SimCLRConfig(JEConfig):
    """Configuration for the SimCLR model parameters.

    Parameters
    ----------
    temperature : float
        Temperature parameter for the contrastive loss. Default is 0.15.
    """

    temperature: float = 0.15

    def trainer(self):
        """Return the trainer class associated with the SimCLR configuration.

        Returns
        -------
        type
            The :class:`SimCLR` trainer class itself (not an instance); the
            caller is responsible for instantiating it.
        """
        return SimCLR