from typing import Callable, Optional

import torch
from torch import nn
class ResidualBlock(nn.Module):
"""A residual block as defined by He et al."""
def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super().__init__()
self.conv_res1 = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride,
bias=False,
)
self.conv_res1_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
self.conv_res2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
bias=False,
)
self.conv_res2_bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.9)
if stride != 1:
            # When stride != 1, the main branch reduces the spatial resolution, so the
            # identity (residual) branch must be downsampled with a strided 1x1 convolution
            # so that both branches have the same shape when they are added together.
self.downsample = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(num_features=out_channels, momentum=0.9),
)
else:
self.downsample = None
self.relu = nn.ReLU(inplace=False)
    def forward(self, x):
        residual = x
        # Main branch: conv -> BN -> ReLU, then conv -> BN.
        out = self.relu(self.conv_res1_bn(self.conv_res1(x)))
        out = self.conv_res2_bn(self.conv_res2(out))
        if self.downsample is not None:
            # Bring the identity branch to the main branch's shape for a strided block.
            residual = self.downsample(residual)
        # Skip connection: note that here the ReLU is applied to the main branch before
        # the addition (He et al. apply it after the sum).
        return self.relu(out) + residual
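# Illustrative usage sketch (not part of the original module): with stride=2 the strided
# main branch halves the spatial resolution and the 1x1 downsample branch matches the
# identity to it. Shapes below assume a 32x32 input; values are an assumption for
# illustration only.
#
#   block = ResidualBlock(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=2)
#   x = torch.randn(8, 64, 32, 32)
#   y = block(x)  # -> torch.Size([8, 64, 16, 16])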
class MLP(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module.
Args:
        in_channels (int, optional): Number of channels of the input. If ``None``, the first
            linear layer is created as ``torch.nn.LazyLinear`` and the input size is inferred
            at the first forward pass.
        hidden_channels (List[int]): List of the hidden channel dimensions; the last entry is
            the output dimension.
        norm_layer (str, optional): If set to ``"batch_norm"``, a ``torch.nn.BatchNorm1d`` layer
            is stacked on top of each hidden linear layer. Default: ``None`` (no normalization
            layer).
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which
            will be stacked on top of the normalization layer (if used), otherwise on top of
            each hidden linear layer. Default: ``torch.nn.ReLU``
inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
bias (bool): Whether to use bias in the linear layer. Default ``True``
dropout (float): The probability for the dropout layer. Default: 0.0
"""
    def __init__(
        self,
        in_channels: Optional[int],
        hidden_channels: list[int],
        norm_layer: Optional[str] = None,
        activation_layer: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired by the implementation in TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}
layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
if in_dim is None:
layers.append(torch.nn.LazyLinear(hidden_dim, bias=bias))
else:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer == "batch_norm":
layers.append(torch.nn.BatchNorm1d(hidden_dim))
layers.append(activation_layer(**params))
layers.append(torch.nn.Dropout(dropout, **params))
in_dim = hidden_dim
        # Output layer; `in_dim` can still be None here if `hidden_channels` has a single entry.
        if in_dim is None:
            layers.append(torch.nn.LazyLinear(hidden_channels[-1], bias=bias))
        else:
            layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))
super().__init__(*layers)
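# Illustrative sketch (not part of the original module) of the layer sequence MLP builds.
# With in_channels=784, hidden_channels=[256, 128, 10] and norm_layer="batch_norm" (these
# values are assumptions for illustration), the resulting Sequential is roughly:
#   Linear(784, 256) -> BatchNorm1d(256) -> ReLU -> Dropout
#   -> Linear(256, 128) -> BatchNorm1d(128) -> ReLU -> Dropout
#   -> Linear(128, 10) -> Dropout
# i.e. the last entry of `hidden_channels` is the output dimension and receives no
# normalization or activation, matching the loop over `hidden_channels[:-1]` above.
#
#   mlp = MLP(in_channels=784, hidden_channels=[256, 128, 10], norm_layer="batch_norm")
#   logits = mlp(torch.randn(32, 784))  # -> torch.Size([32, 10])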
class Resnet9(nn.Module):
"""A Residual network."""
def __init__(self, num_classes, num_channels, *args, **kwargs):
super(Resnet9, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=num_channels,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(num_features=64, momentum=0.9),
nn.ReLU(inplace=False),
nn.Conv2d(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(num_features=128, momentum=0.9),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
ResidualBlock(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
),
nn.Conv2d(
in_channels=128,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(num_features=256, momentum=0.9),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(num_features=256, momentum=0.9),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
ResidualBlock(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
),
nn.AdaptiveMaxPool2d((2, 2)),
)
        # 256 feature maps of size 2x2 after AdaptiveMaxPool2d((2, 2)) -> 256 * 2 * 2 = 1024.
        self.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)
    def forward(self, x):
        # Flatten the (N, 256, 2, 2) feature maps to (N, 1024) before the classifier.
        out = self.conv(x).flatten(1)
        return self.fc(out)
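
# A minimal smoke-test sketch (an assumption for illustration, not part of the original
# module): instantiate Resnet9 for 3-channel 32x32 inputs (e.g. CIFAR-10 sized images) and
# check that the flattened feature size matches the classifier's in_features of 1024
# (256 channels * 2 * 2 after AdaptiveMaxPool2d((2, 2))).
if __name__ == "__main__":
    model = Resnet9(num_classes=10, num_channels=3)
    dummy = torch.randn(4, 3, 32, 32)  # batch of 4 RGB 32x32 images
    logits = model(dummy)
    assert logits.shape == (4, 10), logits.shape
    print(logits.shape)  # torch.Size([4, 10])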