OnlineProbe#

class stable_pretraining.callbacks.OnlineProbe(name: str, input: str, target: str, probe: Module, loss_fn: callable, optimizer: str | dict | partial | Optimizer | None = None, scheduler: str | dict | partial | LRScheduler | None = None, accumulate_grad_batches: int = 1, metrics: dict | tuple | list | Metric | None = None, early_stopping: EarlyStopping | None = None)[source]#

Bases: TrainableCallback

Online probe for evaluating learned representations during self-supervised training.

This callback implements the standard linear evaluation protocol by training a probe (typically a linear classifier) on top of frozen features from the main model. The probe is trained simultaneously with the main model but maintains its own optimizer, scheduler, and training loop. This allows monitoring representation quality throughout training without modifying the base model.

Key features:

  • Automatic gradient detachment to prevent probe gradients affecting the main model

  • Independent optimizer and scheduler management

  • Support for gradient accumulation

  • Mixed precision training compatibility through automatic dtype conversion

  • Built-in early stopping support

  • Metric tracking and logging
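The gradient detachment listed above can be sketched in plain PyTorch (an illustrative sketch of the pattern, not the callback's internal code; `backbone` and `probe` are stand-in modules):

```python
import torch
from torch import nn

# Stand-ins: a frozen-feature "backbone" and a linear probe on top.
backbone = nn.Linear(8, 4)
probe = nn.Linear(4, 2)
loss_fn = nn.CrossEntropyLoss()
probe_opt = torch.optim.AdamW(probe.parameters(), lr=1e-3)

x = torch.randn(16, 8)
y = torch.randint(0, 2, (16,))

# detach() cuts the autograd graph: the probe's loss cannot
# send gradients back into the backbone.
features = backbone(x).detach()
loss = loss_fn(probe(features), y)
loss.backward()
probe_opt.step()

assert backbone.weight.grad is None      # main model untouched
assert probe.weight.grad is not None     # probe receives gradients
```

Because the probe has its own optimizer, its training step leaves the main model's parameters and optimizer state unchanged.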

Parameters:
  • name – Unique identifier for this probe instance. Used for logging and storing metrics/modules.

  • input – Key in batch dict or outputs dict containing input features to probe.

  • target – Key in batch dict containing ground truth target labels.

  • probe – The probe module to train. Can be an nn.Module instance, a callable that returns a module, or a Hydra config to instantiate.

  • loss_fn – Loss function for probe training (e.g., nn.CrossEntropyLoss()).

  • optimizer – Optimizer configuration for the probe. Can be:
    - str: optimizer name (e.g., “AdamW”, “SGD”, “LARS”)
    - dict: {“type”: “AdamW”, “lr”: 1e-3, …}
    - partial: pre-configured optimizer factory
    - an optimizer instance or callable
    - None: inherits from the main Module’s optimizer config (default)

  • scheduler – Learning rate scheduler configuration. Can be:
    - str: scheduler name (e.g., “CosineAnnealingLR”, “StepLR”)
    - dict: {“type”: “CosineAnnealingLR”, “T_max”: 1000, …}
    - partial: pre-configured scheduler factory
    - a scheduler instance or callable
    - None: inherits from the main Module’s scheduler config (default)

  • accumulate_grad_batches – Number of batches to accumulate gradients before optimizer step. Default is 1 (no accumulation).
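The accumulation bookkeeping this parameter controls can be sketched in plain Python (an illustrative sketch of the counting logic, not the callback's internal code; the per-batch losses are stand-in values):

```python
# With accumulate_grad_batches=4, each batch's loss is scaled by 1/4 so the
# summed gradient matches a single large-batch step, and the probe optimizer
# steps once every 4 batches.
accumulate_grad_batches = 4
losses = [1.0] * 8  # stand-in per-batch loss values

steps = 0
accumulated = 0.0
for batch_idx, loss in enumerate(losses):
    accumulated += loss / accumulate_grad_batches  # normalized contribution
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        steps += 1          # optimizer.step() would happen here
        accumulated = 0.0   # followed by optimizer.zero_grad()

assert steps == 2  # 8 batches / 4 accumulated per step
```

The default of 1 reduces this to one optimizer step per batch with no loss scaling.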

  • metrics – Metrics to track during training/validation. Can be dict, list, tuple, or single metric instance.

  • early_stopping – Early stopping configuration to halt training if validation metric stops improving.

Note

  • The probe module is stored in pl_module._callbacks_modules[name]

  • Metrics are stored in pl_module._callbacks_metrics[name]

  • Predictions are stored in batch dict with key ‘{name}_preds’

  • Loss is logged as ‘train/{name}_loss’

  • Metrics are logged with prefix ‘train/{name}_’ and ‘eval/{name}_’
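The naming conventions in the note above can be sketched for a hypothetical probe named "linear" ("top1" is an assumed metric name for illustration, not something the callback defines):

```python
# Keys produced for a probe registered as name="linear".
name = "linear"

preds_key = f"{name}_preds"              # written into the batch dict
loss_key = f"train/{name}_loss"          # logged training loss
train_metric_key = f"train/{name}_top1"  # hypothetical tracked metric
eval_metric_key = f"eval/{name}_top1"

assert preds_key == "linear_preds"
assert loss_key == "train/linear_loss"
assert eval_metric_key == "eval/linear_top1"
```

Namespacing every key by the probe's name is what lets several OnlineProbe instances run side by side without their logs colliding.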

forward_hook_fn(pl_module, args, outputs) → None[source]#

Perform probe training step.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int) → None[source]#

Called when the train batch ends.

Note

The value of outputs["loss"] here is the loss returned from training_step, normalized by accumulate_grad_batches.

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int, dataloader_idx: int = 0) → None[source]#

Compute probe predictions during validation.

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) → None[source]#

Handle early stopping if configured.

property probe_module#

Alias for self.module for backward compatibility.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None[source]#

Initialize optimizer, scheduler, and metrics.