# Source code for stable_ssl.modules

"""Neural network modules."""

#
# Author: Randall Balestriero <randallbalestriero@gmail.com>
#         Hugues Van Assel <vanasselhugues@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
import math

import torch
import torch.nn as nn
import torchvision

from .utils import log_and_raise


def load_backbone(
    name,
    num_classes=None,
    weights=None,
    low_resolution=False,
    return_feature_dim=False,
    **kwargs,
):
    """Load a backbone model.

    If ``num_classes`` is provided, the last layer is replaced by a linear layer of
    output size ``num_classes``. Otherwise, the last layer is replaced by an
    identity layer, so the model returns the features that fed its original head.

    Parameters
    ----------
    name : str
        Name of the backbone model. Supported models are:
        - Any model from torchvision.models
        - "Resnet9" (also accepted as "resnet9")
        - "ConvMixer"
    num_classes : int, optional
        Number of classes in the dataset. If None, the model is loaded without the
        classifier. By default None.
    weights : optional
        Forwarded as ``weights=`` to the torchvision model constructor. Ignored for
        "Resnet9" and "ConvMixer". By default None.
    low_resolution : bool, optional
        Whether to adapt the input stem of the model to low-resolution images
        (for CIFAR typically): the first convolution of ``model.conv1`` becomes a
        3x3 stride-1 convolution and a ``maxpool`` stem layer, if present, is
        removed. By default False.
    return_feature_dim : bool, optional
        Whether to also return the feature dimension of the model.
    **kwargs: dict
        Additional keyword arguments for the model constructor.

    Returns
    -------
    torch.nn.Module or tuple of (torch.nn.Module, int)
        The neural network model, followed by its feature dimension when
        ``return_feature_dim`` is True.

    Raises
    ------
    ValueError
        If ``name`` is not a known model, or if the classification head of the
        model cannot be located.
    """
    # Build the model.
    if name in ("resnet9", "Resnet9"):
        model = Resnet9(**kwargs)
    elif name == "ConvMixer":
        model = ConvMixer(**kwargs)
    else:
        # Keep the try body to the lookup only, so that a KeyError raised inside
        # the constructor is not misreported as an unknown model name.
        try:
            constructor = torchvision.models.__dict__[name]
        except KeyError as err:
            raise ValueError(f"Unknown model: {name}.") from err
        model = constructor(weights=weights, **kwargs)

    def last_layer(in_features):
        # New head: a linear classifier, or a pass-through that exposes features.
        if num_classes is not None:
            return nn.Linear(in_features, num_classes)
        return nn.Identity()

    # Adapt the last layer, either linear or identity.
    if hasattr(model, "fc"):
        # For models like ResNet.
        in_features = model.fc.in_features
        model.fc = last_layer(in_features)
    elif hasattr(model, "classifier"):
        if isinstance(model.classifier, nn.Sequential):
            # For models like VGG or AlexNet: the head is the last sub-layer.
            in_features = model.classifier[-1].in_features
            model.classifier[-1] = last_layer(in_features)
        else:
            # For models like DenseNet whose classifier is a single layer, which
            # cannot be indexed.
            in_features = model.classifier.in_features
            model.classifier = last_layer(in_features)
    elif hasattr(model, "heads"):
        # For models like ViT.
        in_features = model.heads.head.in_features
        model.heads.head = last_layer(in_features)
    elif hasattr(model, "head"):
        # For models like Swin Transformer.
        in_features = model.head.in_features
        model.head = last_layer(in_features)
    else:
        raise ValueError(f"Unknown model structure for : '{name}'.")

    if low_resolution:
        # Reduce early downsampling, for instance for CIFAR.
        def low_res_conv(conv):
            # Same channel counts as the replaced conv, so downstream layers fit.
            return nn.Conv2d(
                conv.in_channels,
                conv.out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            )

        stem = getattr(model, "conv1", None)
        if isinstance(stem, nn.Conv2d):
            # torchvision ResNets: conv1 is a bare convolution.
            model.conv1 = low_res_conv(stem)
        elif (
            isinstance(stem, nn.Sequential)
            and len(stem) > 0
            and isinstance(stem[0], nn.Conv2d)
        ):
            # Resnet9 / ConvMixer: conv1 is (conv, BN, ReLU); only swap the conv so
            # the normalization and activation are kept.
            stem[0] = low_res_conv(stem[0])
        else:
            stem = None
            logging.warning(f"Cannot adapt resolution for model: {name}.")
        if stem is not None and hasattr(model, "maxpool"):
            model.maxpool = nn.Identity()

    if return_feature_dim:
        return model, in_features
    return model
class TeacherStudentModule(nn.Module):
    """Student network and its teacher network updated as an EMA of the student network.

    The teacher model is updated by taking a running average of the student's
    parameters and buffers. When both ``base_ema_coefficient`` and
    ``final_ema_coefficient`` are 0.0, the teacher and student are literally the
    same object, saving memory. Forward passes through the teacher always run under
    ``torch.no_grad()`` and therefore do not produce gradients.

    Parameters
    ----------
    student : torch.nn.Module
        The student model whose parameters will be tracked.
    warm_init : bool, optional
        If True, performs an initialization step to match the student's parameters
        immediately. Default is True.
    base_ema_coefficient : float, optional
        EMA decay factor at the start of training.
        This value will be updated following a cosine schedule.
        Should be in [0, 1]. A value of 0.0 means the teacher is fully
        updated to the student's parameters on every step, while a value of 1.0
        means the teacher remains unchanged.
        Default is 0.996.
    final_ema_coefficient : float, optional
        EMA decay factor at the end of training.
        Default is 1.
    """

    def __init__(
        self,
        student: nn.Module,
        warm_init: bool = True,
        base_ema_coefficient: float = 0.996,
        final_ema_coefficient: float = 1,
    ):
        if not (0.0 <= base_ema_coefficient <= 1.0) or not (
            0.0 <= final_ema_coefficient <= 1.0
        ):
            log_and_raise(
                ValueError,
                f"ema_coefficient must be in [0, 1]. Found: "
                f"base_ema_coefficient={base_ema_coefficient}, "
                f"final_ema_coefficient={final_ema_coefficient}.",
            )

        super().__init__()
        self.student = student
        # 0-dim float tensors stored as plain attributes (not registered buffers).
        self.base_ema_coefficient = torch.Tensor([base_ema_coefficient])[0]
        self.final_ema_coefficient = torch.Tensor([final_ema_coefficient])[0]

        if self.base_ema_coefficient == 0.0 and self.final_ema_coefficient == 0.0:
            # No need to create a teacher network if the EMA coefficient is always 0.
            self.teacher = student
        else:
            # Create a teacher network with the same architecture as the student.
            self.teacher = copy.deepcopy(student)
            self.teacher.requires_grad_(False)  # Teacher should not require gradients.

        if warm_init:
            # A coefficient of 0 makes update_teacher copy the student verbatim.
            self.ema_coefficient = torch.zeros(())
            self.update_teacher()

        # Always defined, whether or not the warm initialization ran.
        self.ema_coefficient = self.base_ema_coefficient.clone()

    @torch.no_grad()
    def update_teacher(self):
        """Perform one EMA update step on the teacher's parameters.

        The update rule is:
            teacher_param = ema_coefficient * teacher_param
            + (1 - ema_coefficient) * student_param

        This is done in a `no_grad` context to ensure the teacher's parameters do
        not accumulate gradients, but the student remains fully trainable.

        Floating-point parameters and buffers (e.g. batch norm running averages)
        are averaged. Non-floating-point tensors (e.g. batch norm's
        ``num_batches_tracked`` counter) cannot hold a fractional average and are
        copied from the student instead.
        """
        if self.teacher is self.student:
            return  # Nothing to update when the teacher is the student.
        if self.ema_coefficient.item() == 1.0:
            return  # No need to update when the teacher is fixed.

        for teacher_group, student_group in [
            (self.teacher.parameters(), self.student.parameters()),
            (self.teacher.buffers(), self.student.buffers()),
        ]:
            for t, s in zip(teacher_group, student_group):
                if not t.is_floating_point():
                    # Casting the coefficient to an integer dtype would round it to
                    # 0 and zero out the tensor; copy the student's value instead.
                    t.copy_(s)
                    continue
                ty = t.dtype
                t.mul_(self.ema_coefficient.to(dtype=ty))
                t.add_((1.0 - self.ema_coefficient).to(dtype=ty) * s)

    @torch.no_grad()
    def update_ema_coefficient(self, epoch: int, total_epochs: int):
        """Update the EMA coefficient following a cosine schedule.

        The EMA coefficient is updated following a cosine schedule:
            ema_coefficient = final_ema_coefficient -
            0.5 * (final_ema_coefficient - base_ema_coefficient)
            * (1 + cos(epoch / total_epochs * pi))

        It equals ``base_ema_coefficient`` at epoch 0 and
        ``final_ema_coefficient`` at ``epoch == total_epochs``.

        Parameters
        ----------
        epoch : int
            Current epoch in the training loop.
        total_epochs : int
            Total number of epochs in the training loop (must be non-zero).
        """
        self.ema_coefficient = self.final_ema_coefficient - 0.5 * (
            self.final_ema_coefficient - self.base_ema_coefficient
        ) * (1 + math.cos(epoch / total_epochs * math.pi))

    def forward_student(self, *args, **kwargs):
        """Forward pass through the student network. Gradients will flow normally."""
        return self.student(*args, **kwargs)

    def forward_teacher(self, *args, **kwargs):
        """Forward pass through the teacher network.

        By default, the teacher network does not require grad.
        If the teacher is the student (both EMA coefficients are 0), the pass is
        still wrapped in torch.no_grad() to ensure no gradients flow.
        """
        with torch.no_grad():
            return self.teacher(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Forward pass through either the student or teacher network.

        You can choose which model to run in the default forward.
        Commonly the teacher is evaluated, so we default to that.
        """
        return self.forward_teacher(*args, **kwargs)
class MLP(nn.Module):
    """Multi-layer perceptron.

    Every hidden layer is ``Linear -> [BatchNorm1d] -> activation``; the network
    ends with a bias-free ``Linear`` projection.

    Parameters
    ----------
    sizes : sequence of int
        Layer widths ``[in, hidden_1, ..., hidden_k, out]``.
    activation : str, optional
        Name of a ``torch.nn`` class instantiated without arguments after each
        hidden layer. By default "ReLU".
    batch_norm : bool, optional
        If True, a ``BatchNorm1d`` follows each hidden ``Linear``, which is then
        created without bias. By default True.
    """

    def __init__(self, sizes, activation="ReLU", batch_norm=True):
        super().__init__()
        stack = []
        # Pair consecutive widths for every hidden transition.
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            stack.append(nn.Linear(fan_in, fan_out, bias=not batch_norm))
            if batch_norm:
                stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.__dict__[activation]())
        # Final projection, always without bias.
        stack.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Forward pass."""
        return self.layers(x)
[docs] class Resnet9(nn.Module): """ResNet-9 model.""" def __init__(self, in_channels=3, num_classes=10): super().__init__() self.conv1 = self.conv_block(in_channels, 64) self.conv2 = self.conv_block(64, 128, pool=True) self.res1 = nn.Sequential(self.conv_block(128, 128), self.conv_block(128, 128)) self.conv3 = self.conv_block(128, 256, pool=True) self.conv4 = self.conv_block(256, 512, pool=True) self.res2 = nn.Sequential(self.conv_block(512, 512), self.conv_block(512, 512)) self.conv5 = self.conv_block(512, 1028, pool=True) self.res3 = nn.Sequential( self.conv_block(1028, 1028), self.conv_block(1028, 1028) ) self.pool = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten()) self.fc = nn.Linear(1028, num_classes) @staticmethod def conv_block(in_channels, out_channels, pool=False): layers = [ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ] if pool: layers.append(nn.MaxPool2d(2)) return nn.Sequential(*layers)
[docs] def forward(self, xb): """Forward pass.""" out = self.conv1(xb) out = self.conv2(out) out = self.res1(out) + out out = self.conv3(out) out = self.conv4(out) out = self.res2(out) + out out = self.conv5(out) out = self.res3(out) + out out = self.fc(self.pool(out)) return out
class ConvMixer(nn.Module):
    """ConvMixer model from :cite:`trockman2022patches`.

    A strided patch-embedding convolution followed by ``depth`` mixer layers.
    Each layer is a residual depthwise convolution followed by a pointwise (1x1)
    convolution, each with BatchNorm and ReLU. Features are then global
    max-pooled and classified.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels. By default 3.
    num_classes : int, optional
        Output size of the final linear layer. By default 10.
    dim : int, optional
        Number of channels used throughout the mixer. By default 64.
    depth : int, optional
        Number of mixer layers. By default 6.
    kernel_size : int, optional
        Kernel size of the depthwise convolutions. By default 9.
    patch_size : int, optional
        Kernel size and stride of the patch-embedding convolution. By default 7.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=10,
        dim=64,
        depth=6,
        kernel_size=9,
        patch_size=7,
    ):
        super().__init__()

        def depthwise():
            return nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.BatchNorm2d(dim),
                nn.ReLU(),
            )

        def pointwise():
            return nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1), nn.BatchNorm2d(dim), nn.ReLU()
            )

        # Patch embedding.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        )
        # All depthwise blocks are built before all pointwise blocks.
        self.blocks_a = nn.ModuleList(depthwise() for _ in range(depth))
        self.blocks_b = nn.ModuleList(pointwise() for _ in range(depth))
        self.pool = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten())
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, xb):
        """Forward pass."""
        h = self.conv1(xb)
        for spatial_mix, channel_mix in zip(self.blocks_a, self.blocks_b):
            # Residual depthwise mixing, then pointwise channel mixing.
            h = channel_mix(h + spatial_mix(h))
        return self.fc(self.pool(h))