load_backbone

stable_ssl.modules.load_backbone(name, num_classes, weights=None, low_resolution=False, return_feature_dim=False, **kwargs)

Load a backbone model.

If num_classes is provided, the last layer is replaced by a linear layer with num_classes outputs. Otherwise, the last layer is replaced by an identity layer, so the model outputs the backbone features directly.

Parameters:
  • name (str) – Name of the backbone model. Supported models are any model from torchvision.models, “Resnet9”, and “ConvMixer”.

  • num_classes (int) – Number of classes in the dataset. If None, the model is loaded without the classifier.

  • weights (bool, optional) – Whether to load pretrained weights. By default None, i.e. no pretrained weights are loaded.

  • low_resolution (bool, optional) – Whether to adapt the model to low-resolution inputs, typically CIFAR-sized 32×32 images. By default False.

  • return_feature_dim (bool, optional) – Whether to also return the feature dimension of the model. By default False.

  • **kwargs (dict) – Additional keyword arguments for the model.

Returns:

The neural network model.

Return type:

torch.nn.Module