ConvMixer


class stable_pretraining.backbone.ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7)

Bases: Module

ConvMixer model.

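A simple and efficient convolutional architecture that operates directly on image patches: after a patch-embedding layer, each block mixes spatial locations with a depthwise convolution and mixes channels with a pointwise (1x1) convolution.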

Parameters:
  • in_channels (int, optional) – Number of input channels. Defaults to 3.

  • num_classes (int, optional) – Number of output classes. Defaults to 10.

  • dim (int, optional) – Number of channels (embedding dimension) of the patch representations inside the network. Defaults to 64.

  • depth (int, optional) – Number of ConvMixer blocks. Defaults to 6.

  • kernel_size (int, optional) – Kernel size for depthwise convolution. Defaults to 9.

  • patch_size (int, optional) – Side length of the square patches extracted by the patch-embedding layer. Defaults to 7.
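To make the effect of these hyperparameters concrete, here is a minimal stdlib-only sketch that counts trainable parameters. It assumes the layout of the original paper's reference design: a patch-embedding convolution with kernel size and stride patch_size, then depth blocks made of a residual depthwise convolution and a pointwise 1x1 convolution (each followed by GELU and BatchNorm), then global average pooling and a linear head. The helper name convmixer_param_count is hypothetical, and stable_pretraining's actual module may differ (for example in bias or normalization choices), so treat the result as an estimate rather than the library's own figure.

```python
def convmixer_param_count(in_channels=3, num_classes=10, dim=64, depth=6,
                          kernel_size=9, patch_size=7):
    """Estimate trainable parameters of a ConvMixer.

    Hypothetical helper. Assumes the paper-style layout: every conv has a
    bias and is followed by BatchNorm2d with a learnable scale and shift.
    """
    bn = 2 * dim  # BatchNorm2d weight + bias (running stats are buffers, not parameters)
    # Patch embedding: Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
    patch_embed = in_channels * dim * patch_size ** 2 + dim + bn
    # Depthwise conv (groups=dim): one kernel_size x kernel_size filter per channel
    depthwise = dim * kernel_size ** 2 + dim + bn
    # Pointwise 1x1 conv: mixes information across channels
    pointwise = dim * dim + dim + bn
    # Linear classifier applied after global average pooling
    head = dim * num_classes + num_classes
    return patch_embed + depth * (depthwise + pointwise) + head


print(convmixer_param_count())  # 68234 with the default arguments
```

Under these assumptions the pointwise term grows with dim ** 2, while the depthwise term grows only linearly in dim (scaled by kernel_size ** 2), so widening the model costs more parameters than enlarging the depthwise kernel.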

Note

Introduced in [Trockman and Kolter, 2022].

forward(xb)

Forward pass through the ConvMixer model.

Parameters:

xb (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).

Returns:

Output logits of shape (batch_size, num_classes).

Return type:

torch.Tensor
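The shape contract of forward can be traced with a small stdlib-only sketch. It assumes the paper-style layout: an unpadded patch-embedding convolution with stride patch_size, mixer blocks that preserve spatial size ("same" padding), global average pooling, and a linear classifier. The function convmixer_shapes is a hypothetical helper, not part of stable_pretraining.

```python
def convmixer_shapes(batch_size, in_channels, height, width,
                     dim=64, depth=6, patch_size=7, num_classes=10):
    """Return (stage name, tensor shape) pairs for an assumed ConvMixer forward pass."""
    shapes = [("input", (batch_size, in_channels, height, width))]
    # Non-overlapping patches: floor((H - p) / p) + 1 == H // p for H >= p
    h, w = height // patch_size, width // patch_size
    shapes.append(("patch_embed", (batch_size, dim, h, w)))
    for i in range(depth):
        # Depthwise + pointwise convolutions keep (dim, h, w) unchanged
        shapes.append((f"block_{i}", (batch_size, dim, h, w)))
    # Global average pooling collapses the spatial grid to one vector per image
    shapes.append(("avg_pool", (batch_size, dim)))
    # Linear head maps the pooled features to class logits
    shapes.append(("logits", (batch_size, num_classes)))
    return shapes


for name, shape in convmixer_shapes(8, 3, 32, 32):
    print(name, shape)
```

Under these assumptions a 32x32 input with the default patch_size=7 becomes a 4x4 grid, and the last 4 rows and columns of pixels fall outside every patch; a patch_size that divides the input resolution avoids this.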