Source code for stable_pretraining.callbacks.probe

from functools import partial
from typing import Dict, Optional, Union

import torch
import torchmetrics
from hydra.utils import instantiate
from lightning.pytorch import LightningModule, Trainer
from loguru import logger as logging

from ..utils import get_data_from_batch_or_outputs

from .utils import EarlyStopping, TrainableCallback, format_metrics_as_dict


class OnlineProbe(TrainableCallback):
    """Online probe for evaluating learned representations during self-supervised training.

    This callback implements the standard linear evaluation protocol by training a
    probe (typically a linear classifier) on top of frozen features from the main
    model. The probe is trained simultaneously with the main model but maintains
    its own optimizer, scheduler, and training loop. This allows monitoring
    representation quality throughout training without modifying the base model.

    Key features:
        - Automatic gradient detachment to prevent probe gradients affecting the main model
        - Independent optimizer and scheduler management
        - Support for gradient accumulation
        - Mixed precision training compatibility through automatic dtype conversion
        - Built-in early stopping support
        - Metric tracking and logging

    Args:
        name: Unique identifier for this probe instance. Used for logging and
            storing metrics/modules.
        input: Key in batch dict or outputs dict containing input features to probe.
        target: Key in batch dict containing ground truth target labels.
        probe: The probe module to train. Can be a nn.Module instance, callable
            that returns a module, or Hydra config to instantiate.
        loss_fn: Loss function for probe training (e.g., nn.CrossEntropyLoss()).
        optimizer: Optimizer configuration for the probe. Can be:
            - str: optimizer name (e.g., "AdamW", "SGD", "LARS")
            - dict: {"type": "AdamW", "lr": 1e-3, ...}
            - partial: pre-configured optimizer factory
            - optimizer instance or callable
            - None: inherits from main Module's optimizer config (default)
        scheduler: Learning rate scheduler configuration. Can be:
            - str: scheduler name (e.g., "CosineAnnealingLR", "StepLR")
            - dict: {"type": "CosineAnnealingLR", "T_max": 1000, ...}
            - partial: pre-configured scheduler factory
            - scheduler instance or callable
            - None: inherits from main Module's scheduler config (default)
        accumulate_grad_batches: Number of batches to accumulate gradients before
            optimizer step. Default is 1 (no accumulation).
        metrics: Metrics to track during training/validation. Can be dict, list,
            tuple, or single metric instance.
        early_stopping: Early stopping configuration to halt training if validation
            metric stops improving.
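    Example:
        An illustrative sketch only; the feature key ("embedding"), label key
        ("label"), dimensions, optimizer settings, and metric name below are
        assumptions made for the example, not defaults of this class::

            probe = OnlineProbe(
                name="linear_probe",
                input="embedding",
                target="label",
                probe=torch.nn.Linear(2048, 1000),
                loss_fn=torch.nn.CrossEntropyLoss(),
                optimizer={"type": "AdamW", "lr": 1e-3},
                scheduler={"type": "CosineAnnealingLR", "T_max": 1000},
                metrics={
                    "top1": torchmetrics.classification.MulticlassAccuracy(
                        num_classes=1000
                    )
                },
            )
            trainer = Trainer(max_epochs=100, callbacks=[probe])
            # After fit, the trained probe lives in
            # pl_module._callbacks_modules["linear_probe"], and its accuracy is
            # logged as "eval/linear_probe_top1".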
    Note:
        - The probe module is stored in pl_module._callbacks_modules[name]
        - Metrics are stored in pl_module._callbacks_metrics[name]
        - Predictions are stored in batch dict with key '{name}_preds'
        - Loss is logged as 'train/{name}_loss'
        - Metrics are logged with prefix 'train/{name}_' and 'eval/{name}_'
    """

    def __init__(
        self,
        name: str,
        input: str,
        target: str,
        probe: torch.nn.Module,
        loss_fn: callable,
        optimizer: Optional[Union[str, dict, partial, torch.optim.Optimizer]] = None,
        scheduler: Optional[
            Union[str, dict, partial, torch.optim.lr_scheduler.LRScheduler]
        ] = None,
        accumulate_grad_batches: int = 1,
        metrics: Optional[Union[dict, tuple, list, torchmetrics.Metric]] = None,
        early_stopping: Optional[EarlyStopping] = None,
    ) -> None:
        # Initialize base class
        super().__init__(
            name=name,
            optimizer=optimizer,
            scheduler=scheduler,
            accumulate_grad_batches=accumulate_grad_batches,
        )

        self.input = input
        self.target = target
        self.loss_fn = loss_fn
        self.early_stopping = early_stopping

        # Store probe configuration for later initialization
        self._probe_config = probe

        # These will be initialized in setup
        self._train_metrics = None
        self._val_metrics = None

        # Format metrics
        self.metrics_config = metrics

        logging.info(f"Initialized OnlineProbe callback: {name}")
        logging.info(f" - Input: {input}")
        logging.info(f" - Target: {target}")
        logging.info(f" - Accumulate grad batches: {accumulate_grad_batches}")

    def _initialize_module(self, pl_module: LightningModule) -> torch.nn.Module:
        """Initialize the probe module from configuration."""
        if isinstance(self._probe_config, torch.nn.Module):
            probe_module = self._probe_config
        elif callable(self._probe_config):
            probe_module = self._probe_config()
        else:
            probe_module = instantiate(self._probe_config, _convert_="object")
        return probe_module
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Initialize optimizer, scheduler, and metrics."""
        # Call parent setup for module/optimizer/scheduler
        super().setup(trainer, pl_module, stage)

        if stage != "fit":
            return

        # Setup metrics
        logging.info(f"{self.name}: Setting up metrics")
        if not hasattr(pl_module, "_callbacks_metrics"):
            logging.info(
                "attaching a `_callbacks_metrics` to your LightningModule for callbacks"
            )
            pl_module._callbacks_metrics = {}
        pl_module._callbacks_metrics[self.name] = format_metrics_as_dict(
            self.metrics_config
        )
        self._train_metrics = pl_module._callbacks_metrics[self.name]["_train"]
        self._val_metrics = pl_module._callbacks_metrics[self.name]["_val"]

        pl_module.register_forward_hook(self.forward_hook_fn)
        logging.info(f"Main module forward hooks: {pl_module._forward_hooks}")
    def forward_hook_fn(self, pl_module, args, outputs) -> None:
        """Perform probe training step."""
        # Extract batch from args tuple (it's the first argument to forward)
        if isinstance(args, tuple) and len(args) > 0:
            batch = args[0]
        else:
            batch = args if not isinstance(args, tuple) else {}

        x = get_data_from_batch_or_outputs(
            self.input, batch, outputs, caller_name=self.name
        )
        y = get_data_from_batch_or_outputs(
            self.target, batch, outputs, caller_name=self.name
        )
        if x is None or y is None:
            logging.warning(f"Callback {self.name} missing x or y")
            return

        self.module.train()
        with torch.enable_grad():
            x = x.detach()
            probe_dtype = next(self.module.parameters()).dtype
            if x.dtype != probe_dtype:
                x = x.to(probe_dtype)
            preds = self.module(x)

        if pl_module.trainer.training:
            loss = self.loss_fn(preds, y)
            loss = loss / self.accumulate_grad_batches
            outputs["loss"] += loss
            logs = {
                f"train/{self.name}_loss": loss.item() * self.accumulate_grad_batches
            }
        else:
            logs = {}

        prediction_key = f"{self.name}_preds"
        if prediction_key not in batch:
            outputs[prediction_key] = preds.detach()

        for metric_name, metric in pl_module._callbacks_metrics[self.name][
            "_train"
        ].items():
            metric(preds.detach(), y)
            logs[f"train/{self.name}_{metric_name}"] = metric

        pl_module.log_dict(logs, on_step=True, on_epoch=True, sync_dist=True)
        return outputs
    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Dict,
        batch: Dict,
        batch_idx: int,
    ) -> None:
        # Optimizer step using parent class method
        self.optimizer_step(batch_idx, trainer)
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Dict,
        batch: Dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Compute probe predictions during validation."""
        # Get input and target data
        x = get_data_from_batch_or_outputs(
            self.input, batch, outputs, caller_name=self.name
        )
        y = get_data_from_batch_or_outputs(
            self.target, batch, outputs, caller_name=self.name
        )
        if x is None or y is None:
            logging.warning(
                "OnlineProbe callback doesn't have access to its `x` or `y` tensor!"
            )
            return

        # Ensure probe is in eval mode
        self.module.eval()

        # Forward pass without gradients
        with torch.inference_mode():
            # Ensure input has same dtype as probe module
            # This handles mixed precision training where features might be float16
            probe_dtype = next(self.module.parameters()).dtype
            if x.dtype != probe_dtype:
                x = x.to(probe_dtype)
            preds = self.module(x)

        # Store predictions in batch
        prediction_key = f"{self.name}_preds"
        if prediction_key not in batch:
            batch[prediction_key] = preds

        # Update metrics and log
        logs = {}
        for metric_name, metric in pl_module._callbacks_metrics[self.name][
            "_val"
        ].items():
            metric(preds, y)
            logs[f"eval/{self.name}_{metric_name}"] = metric

        pl_module.log_dict(logs, on_step=False, on_epoch=True, sync_dist=True)
    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Handle early stopping if configured."""
        if self.early_stopping is None:
            return

        logging.info(f"{self.name} checking for early stopping condition")

        # Get the metric value for early stopping
        metric_name = f"eval/{self.name}_{self.early_stopping.monitor}"
        if metric_name not in trainer.callback_metrics:
            logging.warning(
                f"{self.name}: Early stopping metric {metric_name} not found"
            )
            return

        current_value = trainer.callback_metrics[metric_name]
        should_stop = self.early_stopping.should_stop(
            current_value, trainer.current_epoch
        )
        if should_stop:
            logging.info(
                f"{self.name}: Early stopping triggered at epoch {trainer.current_epoch}"
            )
            trainer.should_stop = True
    @property
    def probe_module(self):
        """Alias for self.module for backward compatibility."""
        return self.module
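

# Illustrative sketch (not part of this module): why the `x.detach()` inside
# `forward_hook_fn` lets the probe train off the shared backward pass without
# leaking gradients into the main model. The backbone, probe, data, and
# dimensions below are hypothetical stand-ins, and the squared-norm loss is
# only a placeholder for the main training objective.
def _detached_probe_sketch() -> None:
    backbone = torch.nn.Linear(8, 4)  # stand-in for the feature extractor
    probe = torch.nn.Linear(4, 2)  # stand-in for the linear probe
    x = torch.randn(3, 8)
    y = torch.randint(0, 2, (3,))

    features = backbone(x)
    main_loss = features.pow(2).mean()  # placeholder main objective
    probe_loss = torch.nn.functional.cross_entropy(probe(features.detach()), y)

    # One shared backward pass, as when the callback adds its loss to
    # outputs["loss"]: the probe receives gradients, while detaching the
    # features keeps the probe loss from contributing to backbone gradients.
    (main_loss + probe_loss).backward()
    assert probe.weight.grad is not None  # probe is trained
    # backbone.weight.grad exists too, but comes only from main_loss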