Source code for stable_pretraining.backbone.resnet9

import torch
from torch import nn


class ResidualBlock(nn.Module):
    """A residual block as defined by He et al."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super(ResidualBlock, self).__init__()
        self.conv_res1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=False,
        )
        self.conv_res1_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
        self.conv_res2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.conv_res2_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)

        if stride != 1:
            # In case stride is not set to 1, we need to downsample the residual so that
            # the dimensions are the same when we add them together.
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(num_features=out_channels, momentum=0.9),
            )
        else:
            self.downsample = None

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv_res1_bn(self.conv_res1(x)))
        out2 = self.conv_res2_bn(self.conv_res2(out))
        if self.downsample is not None:
            residual = self.downsample(residual)
        return self.relu(out2) + residual
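
# --- Usage sketch (illustrative, not part of the original module) -----------
# Assuming a 32x32 feature map: with stride=1 and padding=1 the block preserves
# the spatial resolution, so no downsampling branch is created and the skip
# connection adds cleanly. The shapes below are only an example.
#
#     block = ResidualBlock(in_channels=64, out_channels=64, kernel_size=3,
#                           padding=1, stride=1)
#     x = torch.randn(8, 64, 32, 32)
#     out = block(x)          # shape: (8, 64, 32, 32)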


class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Args:
        in_channels (int): Number of channels of the input.
        hidden_channels (List[int]): List of the hidden channel dimensions.
        norm_layer (str, optional): Name of the normalization stacked on top of each
            hidden linear layer. Currently only ``"batch_norm"`` is recognized.
            If ``None``, no normalization is used. Default: ``None``.
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function
            which will be stacked on top of the normalization layer (if not ``None``),
            otherwise on top of the linear layer. Default: ``torch.nn.ReLU``.
        inplace (bool, optional): Parameter for the activation layer, which can
            optionally do the operation in-place. Default is ``None``, which uses the
            respective default values of the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default: ``True``.
        dropout (float): The probability for the dropout layer. Default: ``0.0``.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: list[int],
        norm_layer: str = None,
        activation_layer=torch.nn.ReLU,
        inplace: bool = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            if in_dim is None:
                layers.append(torch.nn.LazyLinear(hidden_dim, bias=bias))
            else:
                layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer == "batch_norm":
                layers.append(torch.nn.BatchNorm1d(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
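
# --- Usage sketch (illustrative, not part of the original module) -----------
# A two-hidden-layer projector head, assuming a 512-dimensional input. Passing
# in_channels=None would instead use LazyLinear for the first hidden layer.
#
#     head = MLP(in_channels=512, hidden_channels=[1024, 1024, 256],
#                norm_layer="batch_norm", dropout=0.0)
#     z = head(torch.randn(8, 512))   # shape: (8, 256)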


class Resnet9(nn.Module):
    """A Residual network."""

    def __init__(self, num_classes, num_channels, *args, **kwargs):
        super(Resnet9, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=num_channels,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=64, momentum=0.9),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128, momentum=0.9),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(
                in_channels=128,
                out_channels=128,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=256, momentum=0.9),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(
                in_channels=256,
                out_channels=256,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=256, momentum=0.9),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResidualBlock(
                in_channels=256,
                out_channels=256,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.AdaptiveMaxPool2d((2, 2)),
        )

        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def forward(self, x):
        out = self.conv(x).flatten(1)
        return self.fc(out)
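
# --- Usage sketch (illustrative, not part of the original module) -----------
# AdaptiveMaxPool2d((2, 2)) yields 256 channels * 2 * 2 = 1024 features, which
# matches the final Linear layer, so any input resolution that survives the
# three 2x2 max-pools works. Example with CIFAR-sized images:
#
#     model = Resnet9(num_classes=10, num_channels=3)
#     logits = model(torch.randn(8, 3, 32, 32))   # shape: (8, 10)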