"""Template classes to easily instantiate Supervised or SSL trainers."""
#
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
# Randall Balestriero <randallbalestriero@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from .base import BaseTrainer
from .modules import TeacherStudentModule
from .utils import compute_global_mean, log_and_raise
# ==========================================
# Base trainers that require a loss function
# ==========================================
class SupervisedTrainer(BaseTrainer):
r"""Base class for training a supervised model."""
required_modules = {"backbone": torch.nn.Module}
def forward(self, *args, **kwargs):
"""Forward pass. By default, it simply calls the 'backbone' module."""
return self.module["backbone"](*args, **kwargs)
def predict(self):
"""Call the forward pass of current batch."""
return self.forward(self.batch[0])
def compute_loss(self):
"""Compute the loss of the model using the `loss` provided in the config."""
if self.loss is None:
log_and_raise(
ValueError,
f"When using the trainer {self.__class__.__name__}, "
"one needs to either provide a loss function in the config "
"or implement a custom `compute_loss` method.",
)
loss = self.loss(self.predict(), self.batch[1])
return {"loss": loss}
class JointEmbeddingTrainer(BaseTrainer):
r"""Base class for training a joint-embedding SSL model."""
required_modules = {
"backbone": torch.nn.Module,
"projector": torch.nn.Module,
"backbone_classifier": torch.nn.Module,
"projector_classifier": torch.nn.Module,
}
    def format_views_labels(self):
        """Split the current batch into a list of views and optional labels."""
if (
len(self.batch) == 2
and torch.is_tensor(self.batch[1])
and not torch.is_tensor(self.batch[0])
):
            # We assume the second element holds the labels.
views, labels = self.batch
elif (
len(self.batch) > 1
and all([torch.is_tensor(b) for b in self.batch])
and len({b.ndim for b in self.batch}) == 1
):
            # We assume all elements are views.
views = self.batch
labels = None
else:
msg = """You are using the JointEmbedding class with only 1 view!
Make sure to double check your config and datasets definition.
Most methods expect 2 views, some can use more."""
log_and_raise(ValueError, msg)
return views, labels
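    # Batch layouts accepted by format_views_labels (illustrative shapes):
    #   - (views, labels): e.g. ([v1, v2], y) where v1 and v2 are (B, C, H, W)
    #     view tensors and y is a (B,) label tensor -> labels are returned.
    #   - (v1, v2, ...): a tuple of view tensors sharing the same ndim
    #     -> every element is treated as a view and labels is None.
    # Anything else (e.g. a single view) raises a ValueError.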
def forward(self, *args, **kwargs):
"""Forward pass. By default, it simply calls the 'backbone' module."""
return self.module["backbone"](*args, **kwargs)
def predict(self):
"""Call the backbone classifier on the forward pass of current batch."""
return self.module["backbone_classifier"](self.forward(self.batch[0]))
def compute_loss(self):
"""Compute final loss as sum of SSL loss and classifier losses."""
if self.loss is None:
log_and_raise(
ValueError,
f"When using the trainer {self.__class__.__name__}, "
"one needs to either provide a loss function in the config "
"or implement a custom `compute_loss` method.",
)
views, labels = self.format_views_labels()
representations = [self.module["backbone"](view) for view in views]
self._latest_representations = representations
embeddings = [self.module["projector"](rep) for rep in representations]
self._latest_embeddings = embeddings
loss_ssl = self.loss(*embeddings)
classifier_losses = self.compute_loss_classifiers(
representations, embeddings, labels
)
return {"loss_ssl": loss_ssl, **classifier_losses}
def compute_loss_classifiers(self, representations, embeddings, labels):
"""Compute the classifier loss for both backbone and projector."""
loss_backbone_classifier = 0
loss_projector_classifier = 0
# Inputs are detached to avoid backprop through backbone and projector.
if labels is not None:
for rep, embed in zip(representations, embeddings):
loss_backbone_classifier += F.cross_entropy(
self.module["backbone_classifier"](rep.detach()), labels
)
loss_projector_classifier += F.cross_entropy(
self.module["projector_classifier"](embed.detach()), labels
)
return {
"loss_backbone_classifier": loss_backbone_classifier,
"loss_projector_classifier": loss_projector_classifier,
}
@property
def latest_embeddings(self):
if not hasattr(self, "_latest_embeddings"):
return None
return self._latest_embeddings
@latest_embeddings.setter
def latest_embeddings(self, value):
self._latest_embeddings = value
@property
def latest_representations(self):
if not hasattr(self, "_latest_representations"):
return None
return self._latest_representations
@latest_representations.setter
def latest_representations(self, value):
self._latest_representations = value
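# Illustrative sketch (kept as a comment): a SimCLR-style configuration would
# provide a backbone, a projection head, the two linear probes required by
# `required_modules` and a multi-view SSL loss. The dimensions, the constructor
# keywords and `my_ssl_loss` below are assumptions.
#
#   # backbone: a feature extractor returning 512-d representations
#   trainer = JointEmbeddingTrainer(
#       module={
#           "backbone": backbone,
#           "projector": torch.nn.Linear(512, 128),
#           "backbone_classifier": torch.nn.Linear(512, 10),
#           "projector_classifier": torch.nn.Linear(128, 10),
#       },
#       loss=my_ssl_loss,  # hypothetical callable taking one embedding per view
#   )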
class SelfDistillationTrainer(JointEmbeddingTrainer):
r"""Base class for training a self-distillation SSL model."""
required_modules = {
"backbone": TeacherStudentModule,
"projector": TeacherStudentModule,
"backbone_classifier": torch.nn.Module,
"projector_classifier": torch.nn.Module,
}
def compute_loss(self):
"""Compute final loss as sum of SSL loss and classifier losses."""
if self.loss is None:
log_and_raise(
ValueError,
f"When using the trainer {self.__class__.__name__}, "
"one needs to either provide a loss function in the config "
"or implement a custom `compute_loss` method.",
)
views, labels = self.format_views_labels()
representations_student = [
self.module["backbone"].forward_student(view) for view in views
]
embeddings_student = [
self.module["projector"].forward_student(rep)
for rep in representations_student
]
# If a predictor is used, it is applied to the student embeddings.
if "predictor" in self.module:
embeddings_student = [
self.module["predictor"](embed) for embed in embeddings_student
]
representations_teacher = [
self.module["backbone"].forward_teacher(view) for view in views
]
self.latest_representations = representations_teacher
embeddings_teacher = [
self.module["projector"].forward_teacher(rep)
for rep in representations_teacher
]
self.latest_embeddings = embeddings_teacher
loss_ssl = 0.5 * (
self.loss(embeddings_student[0], embeddings_teacher[1])
+ self.loss(embeddings_student[1], embeddings_teacher[0])
)
classifier_losses = self.compute_loss_classifiers(
representations_teacher, embeddings_teacher, labels
)
return {"loss_ssl": loss_ssl, **classifier_losses}
class JointEmbeddingPredictiveTrainer(BaseTrainer):
r"""Base class for training a joint-embedding predictive architecture."""
required_modules = {
"context_encoder": torch.nn.Module,
"target_encoder": torch.nn.Module,
"predictor": torch.nn.Module,
}
    def format_context_target(self):
        """Split the current batch into (context, transforms) and (target, transforms) pairs."""
        if len(self.batch) == 2:
            (context, context_transforms), (target, target_transforms) = self.batch
        else:
            log_and_raise(
                ValueError,
                "JointEmbeddingPredictiveTrainer expects a batch of exactly 2 elements: "
                "(context, context_transforms) and (target, target_transforms).",
            )
return (context, context_transforms), (target, target_transforms)
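    # Expected batch structure (illustrative names):
    #   self.batch = (
    #       (context, context_transforms),  # e.g. a masked/cropped input and its parameters
    #       (target, target_transforms),
    #   )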
def forward(self, *args, **kwargs):
"""Forward pass of the context encoder."""
return self.module["context_encoder"](*args, **kwargs)
def forward_target(self, *args, **kwargs):
"""Forward pass of the target encoder."""
return self.module["target_encoder"](*args, **kwargs)
def forward_predictor(self, *args, **kwargs):
"""Forward pass of the predictor, that transforms the context latents into the target latents."""
return self.module["predictor"](*args, **kwargs)
def compute_loss(self):
"""Compute the final loss as the L1 distance between the predicted and target latents."""
(
(context, context_transformation_parameters),
(target, target_transformation_parameters),
) = self.format_context_target()
context_representations = self.forward(
context, context_transformation_parameters
)
self._latest_representations = context_representations
target_representations = self.forward_target(
target, target_transformation_parameters
)
        # NOTE: the predictor can take additional arguments, such as the transformation parameters.
predicted_representations = self.forward_predictor(
context_representations,
target_representations,
context_transformation_parameters,
target_transformation_parameters,
)
self._latest_embeddings = predicted_representations
loss = self.loss(predicted_representations, target_representations)
return {"loss_ssl": loss}
@property
def latest_embeddings(self):
if not hasattr(self, "_latest_embeddings"):
return None
return self._latest_embeddings
@latest_embeddings.setter
def latest_embeddings(self, value):
self._latest_embeddings = value
@property
def latest_representations(self):
if not hasattr(self, "_latest_representations"):
return None
return self._latest_representations
@latest_representations.setter
def latest_representations(self, value):
self._latest_representations = value
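# Illustrative sketch (kept as a comment): an I-JEPA-like configuration would
# pair a context encoder with an EMA-updated target encoder and a predictor.
# The module objects and the choice of loss below are assumptions.
#
#   trainer = JointEmbeddingPredictiveTrainer(
#       module={
#           "context_encoder": context_encoder,
#           "target_encoder": target_encoder,
#           "predictor": predictor,
#       },
#       loss=torch.nn.functional.l1_loss,  # any callable comparing predicted and target latents
#   )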
# ===============================
# Trainers with Specific Losses
# ===============================
class DINOTrainer(SelfDistillationTrainer):
r"""DINO SSL model by :cite:`caron2021emerging`.
Parameters
----------
warmup_temperature_teacher : float, optional
The initial temperature for the teacher output.
Default is 0.04.
temperature_teacher : float, optional
The temperature for the teacher output.
Default is 0.07.
warmup_epochs_temperature_teacher : int, optional
The number of epochs to warm up the teacher temperature.
Default is 30.
temperature_student : float, optional
The temperature for the student output.
Default is 0.1.
center_momentum : float, optional
The momentum used to update the center.
Default is 0.9.
**kwargs
Additional arguments passed to the base class.
"""
def __init__(
self,
warmup_temperature_teacher: float = 0.04,
temperature_teacher: float = 0.07,
warmup_epochs_temperature_teacher: int = 30,
temperature_student: float = 0.1,
center_momentum: float = 0.9,
**kwargs,
):
super().__init__(
warmup_temperature_teacher=warmup_temperature_teacher,
temperature_teacher=temperature_teacher,
warmup_epochs_temperature_teacher=warmup_epochs_temperature_teacher,
temperature_student=temperature_student,
center_momentum=center_momentum,
**kwargs,
)
self.temperature_teacher_schedule = torch.linspace(
start=warmup_temperature_teacher,
end=temperature_teacher,
steps=warmup_epochs_temperature_teacher,
)
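        # With the defaults above, the teacher temperature follows
        # torch.linspace(0.04, 0.07, 30): 0.04 at epoch 0, ~0.041 at epoch 1,
        # ..., 0.07 at epoch 29, and stays at temperature_teacher afterwards.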
def compute_loss(self):
"""Compute the DINO loss."""
views, labels = self.format_views_labels()
representations_student = [
self.module["backbone"].forward_student(view) for view in views
]
embeddings_student = [
self.module["projector"].forward_student(rep)
for rep in representations_student
]
# Construct target *from global views only* with the target ('teacher') network.
with torch.no_grad():
            global_views = views[:2]  # The first two views are assumed to be the global views.
representations_teacher = [
self.module["backbone"].forward_teacher(view) for view in global_views
]
self.latest_representations = representations_teacher
embeddings_teacher = [
self.module["projector"].forward_teacher(rep)
for rep in representations_teacher
]
self.latest_embeddings = embeddings_teacher
if self.epoch < self.warmup_epochs_temperature_teacher:
temperature_teacher = self.temperature_teacher_schedule[self.epoch]
else:
temperature_teacher = self.temperature_teacher
stacked_embeddings_teacher = torch.stack(embeddings_teacher)
if hasattr(self, "center"):
probs_teacher = F.softmax(
(stacked_embeddings_teacher - self.center) / temperature_teacher,
dim=-1,
)
else:
probs_teacher = F.softmax(
stacked_embeddings_teacher / temperature_teacher, dim=-1
)
stacked_embeddings_student = torch.stack(embeddings_student)
log_probs_student = F.log_softmax(
stacked_embeddings_student / self.temperature_student, dim=-1
)
# Compute the cross entropy loss between the student and teacher probabilities.
probs_teacher_flat = probs_teacher.flatten(start_dim=1)
log_probs_student_flat = log_probs_student.flatten(start_dim=1)
loss_ssl = -probs_teacher_flat @ log_probs_student_flat.T
loss_ssl.fill_diagonal_(0)
# Normalize the loss.
n_terms = loss_ssl.numel() - loss_ssl.diagonal().numel()
batch_size = stacked_embeddings_teacher.shape[1]
loss_ssl = loss_ssl.sum() / (n_terms * batch_size)
# Update the center of the teacher network.
with torch.no_grad():
batch_center = compute_global_mean(stacked_embeddings_teacher, dim=(0, 1))
if not hasattr(self, "center"):
self.center = batch_center
else:
self.center = self.center * self.center_momentum + batch_center * (
1 - self.center_momentum
)
classifier_losses = self.compute_loss_classifiers(
representations_teacher, embeddings_teacher, labels
)
return {"loss_ssl": loss_ssl, **classifier_losses}