Source code for stable_pretraining.backbone.probe

import torch
import torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from .utils import register_lr_scale_hook
from typing import List, Union, Optional, Callable, Dict, Any
import torch.nn as nn
from loguru import logger


class MultiHeadAttentiveProbe(torch.nn.Module):
    """A multi-head attentive probe for sequence representations.

    This module applies multiple attention heads to a sequence of embeddings,
    pools the sequence into a fixed-size representation per head, concatenates
    the results, and projects to a set of output classes.

    Args:
        embedding_dim (int): Dimensionality of the input embeddings.
        num_classes (int): Number of output classes.
        num_heads (int, optional): Number of attention heads. Default is 4.

    Attributes:
        ln (torch.nn.LayerNorm): Layer normalization applied to the input.
        attn_vectors (torch.nn.Parameter): Learnable attention vectors for each
            head, shape (num_heads, embedding_dim).
        fc (torch.nn.Linear): Final linear layer mapping concatenated head
            outputs to class logits.

    Forward Args:
        x (torch.Tensor): Input tensor of shape (N, T, D), where
            N = batch size, T = sequence length, D = embedding_dim.

    Returns:
        torch.Tensor: Output logits of shape (N, num_classes).

    Example:
        >>> probe = MultiHeadAttentiveProbe(
        ...     embedding_dim=128, num_classes=10, num_heads=4
        ... )
        >>> x = torch.randn(32, 20, 128)  # batch of 32, sequence length 20
        >>> logits = probe(x)  # shape: (32, 10)
    """

    def __init__(self, embedding_dim: int, num_classes: int, num_heads: int = 4):
        super().__init__()
        self.ln = torch.nn.LayerNorm(embedding_dim)
        self.attn_vectors = torch.nn.Parameter(torch.randn(num_heads, embedding_dim))
        self.fc = torch.nn.Linear(embedding_dim * num_heads, num_classes)
    def forward(self, x: torch.Tensor):
        # x: (N, T, D)
        x = self.ln(x)
        # Compute attention scores for each head: (N, num_heads, T)
        attn_scores = torch.einsum("ntd,hd->nht", x, self.attn_vectors)
        attn_weights = torch.softmax(attn_scores, dim=-1)  # (N, num_heads, T)
        # Weighted sum for each head: (N, num_heads, D)
        pooled = torch.einsum("ntd,nht->nhd", x, attn_weights)
        pooled = pooled.reshape(x.size(0), -1)  # (N, num_heads * D)
        out = self.fc(pooled)  # (N, num_classes)
        return out
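
# Illustrative sketch (not part of the library): a quick shape check for
# MultiHeadAttentiveProbe. Each head softmax-normalizes its scores over the
# time axis, so the per-head weights form a distribution over the T tokens and
# the pooled output concatenates num_heads weighted averages of the embeddings.
# The function name `_demo_multi_head_attentive_probe` is hypothetical.
def _demo_multi_head_attentive_probe():
    probe = MultiHeadAttentiveProbe(embedding_dim=64, num_classes=5, num_heads=2)
    x = torch.randn(8, 12, 64)  # (N, T, D)
    logits = probe(x)
    assert logits.shape == (8, 5)
    # Recompute the attention weights and confirm they sum to 1 over time.
    scores = torch.einsum("ntd,hd->nht", probe.ln(x), probe.attn_vectors)
    weights = torch.softmax(scores, dim=-1)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(8, 2), atol=1e-5)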

class LinearProbe(torch.nn.Module):
    """Linear probe using either the CLS token or mean pooling, with a configurable normalization layer.

    Args:
        embedding_dim (int): Dimensionality of the input embeddings.
        num_classes (int): Number of output classes.
        pooling (str or None): Pooling strategy: 'cls', 'mean', or None to
            flatten the non-batch dimensions.
        norm_layer (callable or None): Normalization layer class (e.g.,
            torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no
            normalization. Should accept a single argument: normalized_shape
            or num_features.

    Attributes:
        norm (nn.Module or None): Instantiated normalization layer, or None.
        fc (nn.Linear): Linear layer mapping the pooled representation to class logits.

    Forward Args:
        x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D). If 3D,
            pooling is applied. If 2D, the input is used directly (no pooling).

    Returns:
        torch.Tensor: Output logits of shape (N, num_classes).

    Example:
        >>> probe = LinearProbe(
        ...     embedding_dim=128,
        ...     num_classes=10,
        ...     pooling="mean",
        ...     norm_layer=torch.nn.LayerNorm,
        ... )
        >>> x = torch.randn(32, 20, 128)
        >>> logits = probe(x)  # shape: (32, 10)
        >>> x2 = torch.randn(32, 128)
        >>> logits2 = probe(x2)  # shape: (32, 10)
    """

    def __init__(self, embedding_dim, num_classes, pooling="cls", norm_layer=None):
        super().__init__()
        assert pooling in (
            "cls",
            "mean",
            None,
        ), "pooling must be 'cls', 'mean', or None"
        self.pooling = pooling
        self.norm = norm_layer(embedding_dim) if norm_layer is not None else None
        self.fc = torch.nn.Linear(embedding_dim, num_classes)
    def forward(self, x):
        # x: (N, T, D) or (N, D)
        if x.ndim == 2:
            # (N, D): no pooling needed
            pooled = x
        elif self.pooling == "cls":
            pooled = x[:, 0, :]  # (N, D)
        elif self.pooling == "mean":
            pooled = x.mean(dim=1)  # (N, D)
        else:
            pooled = x.flatten(1)
        # Apply normalization only if a norm layer was configured.
        if self.norm is not None:
            pooled = self.norm(pooled)
        out = self.fc(pooled)  # (N, num_classes)
        return out
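
# Illustrative sketch (not part of the library): LinearProbe accepts either a
# token sequence (N, T, D) or already-pooled features (N, D); with the default
# norm_layer=None the pooled features go straight into the linear layer.
# The function name `_demo_linear_probe` is hypothetical.
def _demo_linear_probe():
    probe = LinearProbe(embedding_dim=64, num_classes=5, pooling="mean")
    tokens = torch.randn(8, 12, 64)  # mean-pooled over the 12 tokens
    assert probe(tokens).shape == (8, 5)
    features = torch.randn(8, 64)  # used as-is, no pooling
    assert probe(features).shape == (8, 5)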

class AutoLinearClassifier(torch.nn.Module):
    """Grid of linear classifiers trained in parallel over a set of hyperparameters.

    For every combination of normalization, dropout, label smoothing, learning-rate
    scaling, and weight decay, a separate linear head (normalization, dropout,
    linear) is created on top of the pooled representation. Each head is trained
    with cross-entropy, and a per-head top-1 accuracy metric is tracked so the best
    configuration can be identified after training.

    Args:
        name (str): Base name used to build the identifier of each head.
        embedding_dim (int): Dimensionality of the input embeddings.
        num_classes (int): Number of output classes.
        pooling (str or None): Pooling strategy: 'cls', 'mean', or None to flatten.
        weight_decay (list): Weight decay values to try.
        lr_scaling (list): Learning rate scaling factors to try.
        normalization (list): Normalization types among 'none', 'norm' (LayerNorm),
            and 'bn' (BatchNorm1d).
        dropout (list): Dropout rates to try.
        label_smoothing (list): Label smoothing values to try; each value is
            divided by num_classes when building the cross-entropy loss.

    Attributes:
        fc (nn.ModuleDict): One linear head per hyperparameter combination.
        losses (nn.ModuleDict): Cross-entropy loss per head.
        metrics (torchmetrics.MetricCollection): Top-1 accuracy per head.

    Forward Args:
        x (torch.Tensor): Input tensor of shape (N, T, D), (N, D, H, W), or (N, D).
        y (torch.Tensor): Target labels of shape (N,).
        pl_module: Optional Lightning module used to log losses and metrics.

    Returns:
        torch.Tensor: Sum of the per-head losses.

    Example:
        >>> clf = AutoLinearClassifier("probe", embedding_dim=128, num_classes=10)
        >>> x = torch.randn(32, 128)
        >>> y = torch.randint(0, 10, (32,))
        >>> loss = clf(x, y)  # scalar: sum of the per-head losses
    """

    def __init__(
        self,
        name,
        embedding_dim,
        num_classes,
        pooling=None,
        weight_decay=[0],
        lr_scaling=[1],
        normalization=["none", "norm", "bn"],
        dropout=[0, 0.5],
        label_smoothing=[0, 1],
    ):
        super().__init__()
        assert pooling in (
            "cls",
            "mean",
            None,
        ), "pooling must be 'cls', 'mean', or None"
        self.pooling = pooling
        self.fc = torch.nn.ModuleDict()
        self.losses = torch.nn.ModuleDict()
        metrics = {}
        for lr in lr_scaling:
            for wd in weight_decay:
                for norm in normalization:
                    for drop in dropout:
                        for ls in label_smoothing:
                            if norm == "bn":
                                layer_norm = torch.nn.BatchNorm1d(embedding_dim)
                            elif norm == "norm":
                                layer_norm = torch.nn.LayerNorm(embedding_dim)
                            else:
                                assert norm == "none"
                                layer_norm = torch.nn.Identity()
                            key = f"{name}_{norm}_{drop}_{ls}_{lr}_{wd}".replace(".", "")
                            self.fc[key] = torch.nn.Sequential(
                                layer_norm,
                                torch.nn.Dropout(drop),
                                torch.nn.Linear(embedding_dim, num_classes),
                            )
                            register_lr_scale_hook(self.fc[key], lr, wd)
                            self.losses[key] = torch.nn.CrossEntropyLoss(
                                label_smoothing=ls / num_classes
                            )
                            metrics[key] = MulticlassAccuracy(num_classes)
        self.metrics = torchmetrics.MetricCollection(
            metrics, prefix="eval/", postfix="_top1"
        )
    def forward(self, x, y=None, pl_module=None):
        # x: (N, T, D), (N, D, H, W), or (N, D)
        if x.ndim == 2:
            # (N, D): no pooling needed
            pooled = x
        elif self.pooling == "cls":
            assert x.ndim == 3
            pooled = x[:, 0, :]  # (N, D)
        elif self.pooling == "mean":
            if x.ndim == 3:
                pooled = x.mean(dim=1)  # (N, D)
            else:
                assert x.ndim == 4
                pooled = x.mean(dim=(2, 3))  # (N, D)
        else:
            pooled = x.flatten(1)
        loss = {}
        for name in self.fc.keys():
            yhat = self.fc[name](pooled)
            loss[f"train/{name}"] = self.losses[name](yhat, y)
            if not self.training:
                self.metrics[name].update(yhat, y)
        if self.training and pl_module:
            pl_module.log_dict(loss, on_step=True, on_epoch=False, rank_zero_only=True)
        elif pl_module:
            pl_module.log_dict(
                self.metrics,
                on_step=False,
                on_epoch=True,
                sync_dist=True,
            )
        return sum(loss.values())
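
# Illustrative sketch (not part of the library): with the default grid
# (3 normalizations x 2 dropouts x 2 label-smoothing values), AutoLinearClassifier
# builds 12 linear heads, evaluates all of them on the same pooled features, and
# returns the sum of their losses so a single backward() updates every head.
# The function name `_demo_auto_linear_classifier` is hypothetical.
def _demo_auto_linear_classifier():
    clf = AutoLinearClassifier("probe", embedding_dim=64, num_classes=5)
    assert len(clf.fc) == 3 * 2 * 2  # normalization x dropout x label smoothing
    x = torch.randn(8, 64)
    y = torch.randint(0, 5, (8,))
    loss = clf(x, y)  # scalar: sum over the 12 per-head cross-entropy losses
    loss.backward()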

class AutoTuneMLP(nn.Module):
    """Automatically creates multiple MLP variants with different hyperparameter combinations.

    This module creates a grid of MLPs with different configurations (dropout,
    normalization, learning rates, architectures, etc.) to enable parallel
    hyperparameter tuning.

    Args:
        in_features: Number of input features
        out_features: Number of output features
        hidden_features: Architecture specification. Can be:
            - List[int]: Single architecture, e.g., [256, 128]
            - List[List[int]]: Multiple architectures, e.g., [[256, 128], [512, 256, 128]]
            - []: Empty list for a linear model (no hidden layers)
        name: Base name for this AutoTuneMLP instance
        loss_fn: Loss function used to compute the loss
        additional_weight_decay: List of weight decay values to try
        lr_scaling: List of learning rate scaling factors to try
        normalization: List of normalization types ['none', 'norm', 'bn']
        dropout: List of dropout rates to try
        activation: List of activation functions ['relu', 'leaky_relu', 'tanh']

    Examples:
        >>> # Single architecture
        >>> model = AutoTuneMLP(128, 10, [256, 128], "clf", nn.CrossEntropyLoss())
        >>> # Multiple architectures
        >>> model = AutoTuneMLP(
        ...     128, 10, [[256], [256, 128], [512, 256]], "clf", nn.CrossEntropyLoss()
        ... )
        >>> # Linear model (no hidden layers)
        >>> model = AutoTuneMLP(128, 10, [], "linear_clf", nn.CrossEntropyLoss())
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Union[List[int], List[List[int]]],
        name: str,
        loss_fn: Callable,
        additional_weight_decay: Union[float, List[float]] = [0],
        lr_scaling: Union[float, List[float]] = [1],
        normalization: Union[str, List[str]] = ["none"],
        dropout: Union[float, List[float]] = [0],
        activation: Union[str, List[str]] = ["relu"],
    ):
        super().__init__()
        logger.info(f"Initializing AutoTuneMLP: {name}")
        logger.debug(f"Input features: {in_features}, Output features: {out_features}")

        self.mlp = nn.ModuleDict()
        self.in_features = in_features
        self.out_features = out_features
        self.loss_fn = loss_fn
        self.name = name

        # Normalize hidden_features to list of lists
        self.hidden_features = self._normalize_architectures(hidden_features)
        logger.debug(f"Architectures to try: {self.hidden_features}")

        # Store hyperparameter configurations
        self.lr_scaling = lr_scaling
        self.additional_weight_decay = additional_weight_decay
        self.normalization = normalization
        self.dropout = dropout
        self.activation = activation

        # Generate all MLP variants
        self._build_mlp_variants()
        logger.info(f"Created {len(self.mlp)} MLP variants for {name}")

    @staticmethod
    def _normalize_architectures(
        hidden_features: Union[List[int], List[List[int]]],
    ) -> List[List[int]]:
        """Normalize hidden_features to the list-of-lists format.

        Args:
            hidden_features: Single architecture or list of architectures

        Returns:
            List of architecture configurations

        Examples:
            >>> AutoTuneMLP._normalize_architectures([256, 128])
            [[256, 128]]
            >>> AutoTuneMLP._normalize_architectures([[256], [256, 128]])
            [[256], [256, 128]]
            >>> AutoTuneMLP._normalize_architectures([])
            [[]]
        """
        # Empty list means linear model
        if len(hidden_features) == 0:
            logger.info("Linear model configuration (no hidden layers)")
            return [[]]

        # Check if it's a list of lists or a single list
        if isinstance(hidden_features[0], list):
            logger.info(f"Multiple architectures: {len(hidden_features)} variants")
            return hidden_features
        else:
            logger.info(f"Single architecture: {hidden_features}")
            return [hidden_features]

    def _build_mlp_variants(self) -> None:
        """Build all MLP variants based on the hyperparameter grid."""
        variant_count = 0

        for arch_idx, arch in enumerate(self.hidden_features):
            arch_name = self._get_arch_name(arch, arch_idx)

            for lr in self._to_list(self.lr_scaling):
                for wd in self._to_list(self.additional_weight_decay):
                    for norm in self._get_norm_layers():
                        for act in self._get_activation_layers():
                            for drop in self._to_list(self.dropout):
                                # Create a unique ID for this variant
                                norm_name = self._get_layer_name(norm)
                                act_name = self._get_layer_name(act)
                                variant_id = (
                                    f"{self.name}_{arch_name}_{norm_name}_{act_name}_"
                                    f"drop{drop}_lr{lr}_wd{wd}"
                                ).replace(".", "_")

                                logger.debug(f"Creating variant: {variant_id}")

                                # Build MLP
                                self.mlp[variant_id] = self._create_mlp(
                                    arch, drop, norm, act
                                )

                                # Register learning rate and weight decay hooks
                                self._register_lr_scale_hook(
                                    self.mlp[variant_id], lr, wd
                                )
                                variant_count += 1

        logger.info(f"Successfully built {variant_count} MLP variants")

    @staticmethod
    def _get_arch_name(architecture: List[int], index: int) -> str:
        """Get a readable name for an architecture.

        Args:
            architecture: List of hidden dimensions
            index: Architecture index

        Returns:
            String representation of the architecture
        """
        if len(architecture) == 0:
            return "linear"
        return f"arch{index}_" + "x".join(map(str, architecture))

    @staticmethod
    def _to_list(value: Union[Any, List[Any]]) -> List[Any]:
        """Convert a single value to a list if needed."""
        return value if isinstance(value, (list, tuple)) else [value]

    def _get_norm_layers(self) -> List[Optional[type]]:
        """Get the list of normalization layer types."""
        norm_map = {
            "bn": nn.BatchNorm1d,
            "norm": nn.LayerNorm,
            "none": None,
            None: None,
        }
        layers = []
        for case in self._to_list(self.normalization):
            if case not in norm_map:
                logger.warning(f"Unknown normalization: {case}, skipping")
                continue
            layers.append(norm_map[case])
        return layers if layers else [None]

    def _get_activation_layers(self) -> List[type]:
        """Get the list of activation layer types."""
        act_map = {
            "relu": nn.ReLU,
            "leaky_relu": nn.LeakyReLU,
            "tanh": nn.Tanh,
            None: nn.Identity,
        }
        layers = []
        for case in self._to_list(self.activation):
            if case not in act_map:
                logger.warning(f"Unknown activation: {case}, skipping")
                continue
            layers.append(act_map[case])
        return layers if layers else [nn.Identity]

    @staticmethod
    def _get_layer_name(layer: Optional[type]) -> str:
        """Get a readable name for a layer type."""
        if layer is None:
            return "none"
        return layer.__name__.lower()

    def _create_mlp(
        self,
        hidden_features: List[int],
        dropout: float,
        norm_layer: Optional[type],
        activation: type,
    ) -> nn.Sequential:
        """Create a single MLP with the specified configuration.

        Args:
            hidden_features: List of hidden dimensions (empty for a linear model)
            dropout: Dropout rate
            norm_layer: Normalization layer class (or None)
            activation: Activation layer class

        Returns:
            Sequential module containing the MLP
        """
        layers = []

        # Handle linear model (no hidden layers)
        if len(hidden_features) == 0:
            logger.trace("Creating linear model (no hidden layers)")
            layers.append(nn.Linear(self.in_features, self.out_features))
            return nn.Sequential(*layers)

        # Build hidden layers
        in_dim = self.in_features
        for i, hidden_dim in enumerate(hidden_features):
            layers.append(nn.Linear(in_dim, hidden_dim))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation())
            layers.append(nn.Dropout(dropout))
            in_dim = hidden_dim

        # Output layer
        layers.append(nn.Linear(in_dim, self.out_features))

        logger.trace(f"Created MLP with {len(hidden_features)} hidden layers")
        return nn.Sequential(*layers)

    def _register_lr_scale_hook(
        self, module: nn.Module, lr_scale: float, weight_decay: float
    ) -> None:
        """Register learning rate scaling and weight decay for a module.

        Note: This is a placeholder - implement based on your training framework.
        """
        # Store as module attributes for the optimizer to access
        module.lr_scale = lr_scale
        module.weight_decay = weight_decay
        logger.trace(f"Registered lr_scale={lr_scale}, weight_decay={weight_decay}")
    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Forward pass through all MLP variants.

        Args:
            x: Input tensor of shape (batch_size, in_features)
            y: Optional target tensor for loss computation

        Returns:
            Dictionary with predictions and losses for each variant.
            Format: {'pred/{variant_id}': tensor, 'loss/{variant_id}': tensor}
        """
        output = {}
        logger.debug(f"Forward pass with input shape: {x.shape}")

        for variant_id, mlp in self.mlp.items():
            # Get prediction
            pred = mlp(x)
            output[f"pred/{variant_id}"] = pred

            # Compute loss if targets are provided
            if y is not None:
                loss = self.loss_fn(pred, y)
                output[f"loss/{variant_id}"] = loss
                logger.trace(f"{variant_id} loss: {loss.item():.4f}")

        logger.debug(f"Computed outputs for {len(self.mlp)} variants")
        return output
    def keys(self) -> List[str]:
        """Get the list of all MLP variant names.

        Returns:
            List of variant IDs (strings)

        Example:
            >>> model = AutoTuneMLP(
            ...     128, 10, [[256], [512]], "clf", nn.CrossEntropyLoss()
            ... )
            >>> model.keys()
            ['clf_arch0_256_none_relu_drop0_lr1_wd0', 'clf_arch1_512_none_relu_drop0_lr1_wd0']
        """
        return list(self.mlp.keys())
    def get_variant(self, key: str) -> nn.Module:
        """Get a specific MLP variant by key.

        Args:
            key: Variant ID

        Returns:
            The MLP module

        Raises:
            KeyError: If the key doesn't exist
        """
        if key not in self.mlp:
            available = self.keys()
            logger.error(f"Variant '{key}' not found. Available: {available[:5]}...")
            raise KeyError(f"Variant '{key}' not found")
        return self.mlp[key]
    def get_best_variant(
        self, metric_dict: Dict[str, float], lower_is_better: bool = True
    ) -> str:
        """Get the best-performing variant based on metrics.

        Args:
            metric_dict: Dictionary mapping variant_id to metric values
            lower_is_better: If True, a lower metric is better (e.g., loss).
                If False, higher is better (e.g., accuracy)

        Returns:
            ID of the best-performing variant
        """
        if lower_is_better:
            best_variant = min(metric_dict, key=metric_dict.get)
        else:
            best_variant = max(metric_dict, key=metric_dict.get)

        best_score = metric_dict[best_variant]
        logger.info(f"Best variant: {best_variant} with score: {best_score:.4f}")
        return best_variant
    def num_variants(self) -> int:
        """Get the number of MLP variants."""
        return len(self.mlp)
    def __len__(self) -> int:
        """Get the number of MLP variants."""
        return len(self.mlp)

    def __repr__(self) -> str:
        return (
            f"AutoTuneMLP(name={self.name}, variants={len(self.mlp)}, "
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"architectures={len(self.hidden_features)})"
        )
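
# Illustrative sketch (not part of the library): a minimal AutoTuneMLP tuning
# loop. The forward pass returns one 'pred/...' and one 'loss/...' entry per
# variant; summing the per-variant losses trains all heads at once, and
# get_best_variant picks the configuration with the lowest aggregated loss.
# The function name `_demo_auto_tune_mlp` is hypothetical.
def _demo_auto_tune_mlp():
    model = AutoTuneMLP(
        in_features=64,
        out_features=5,
        hidden_features=[[128], []],  # one hidden-layer MLP and one linear head
        name="clf",
        loss_fn=nn.CrossEntropyLoss(),
        dropout=[0, 0.5],
    )
    assert len(model) == 2 * 2  # 2 architectures x 2 dropout rates
    x = torch.randn(8, 64)
    y = torch.randint(0, 5, (8,))
    output = model(x, y)
    total_loss = sum(v for k, v in output.items() if k.startswith("loss/"))
    total_loss.backward()
    # Aggregate a per-variant metric (here: the current loss value) and select.
    metric_dict = {
        k[len("loss/"):]: v.item()
        for k, v in output.items()
        if k.startswith("loss/")
    }
    best = model.get_best_variant(metric_dict, lower_is_better=True)
    assert best in model.keys()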