Source code for stable_pretraining.callbacks.utils

import types
from functools import partial
from typing import Optional, Union

import torch
import torchmetrics
from lightning.pytorch import Callback, LightningModule
from loguru import logger as logging

from ..optim import LARS, create_optimizer, create_scheduler


class TrainableCallback(Callback):
    """Base callback class with optimizer and scheduler management.

    This base class handles the common logic for callbacks that need their own
    optimizer and scheduler, including automatic inheritance from the main module's
    configuration when not explicitly specified.

    Subclasses should:
    1. Call super().__init__() with appropriate parameters
    2. Store their module configuration in self._module_config
    3. Override configure_model() to create their specific module
    4. Access their module via self.module property after setup
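
    Example:
        A minimal subclass sketch (illustrative only; `FrozenBackboneProbe`, its
        linear head, and its constructor arguments are hypothetical and not part
        of this module)::

            class FrozenBackboneProbe(TrainableCallback):
                def __init__(self, module, name, in_features, num_classes, **kwargs):
                    super().__init__(module=module, name=name, **kwargs)
                    self.in_features = in_features
                    self.num_classes = num_classes
                    # formatted via `format_metrics_as_dict` when the wrapped
                    # `configure_model` runs
                    self.metrics = {}

                def configure_model(self, pl_module):
                    # The returned module is stored centrally in
                    # `pl_module.callbacks_modules[self.name]`.
                    return torch.nn.Linear(self.in_features, self.num_classes)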
    """

    def __init__(
        self,
        module: LightningModule,
        name: str,
        optimizer: Optional[Union[str, dict, partial, torch.optim.Optimizer]] = None,
        scheduler: Optional[
            Union[str, dict, partial, torch.optim.lr_scheduler.LRScheduler]
        ] = None,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: str = "norm",
    ):
        """Initialize base callback with optimizer/scheduler configuration.

        Args:
            module: The spt.Module (LightningModule) whose `configure_model` and
                `configure_optimizers` will be wrapped.
            name: Unique identifier for this callback instance.
            optimizer: Optimizer configuration. If None, uses the default LARS optimizer.
            scheduler: Scheduler configuration. If None, uses the default ConstantLR scheduler.
            accumulate_grad_batches: Number of batches over which to accumulate
                gradients before each optimizer step.
            gradient_clip_val: Value to clip gradients to (default None, i.e. no clipping).
            gradient_clip_algorithm: Algorithm used to clip gradients (default `norm`).
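
        Example:
            The optimizer and scheduler accept several forms (string, dict,
            `functools.partial`, or an instance), per the type hints above. A
            partial-based configuration might look like the sketch below, using
            the hypothetical `FrozenBackboneProbe` subclass from the class
            docstring; the exact string/dict formats accepted depend on
            `create_optimizer` and `create_scheduler`::

                probe = FrozenBackboneProbe(
                    module=module,
                    name="linear_probe",
                    in_features=2048,
                    num_classes=10,
                    optimizer=partial(torch.optim.AdamW, lr=1e-3, weight_decay=1e-4),
                    scheduler=partial(
                        torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100
                    ),
                    accumulate_grad_batches=2,
                )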
        """
        super().__init__()
        self.name = name
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm

        # Store configurations
        self._optimizer_config = optimizer
        self._scheduler_config = scheduler
        self._pl_module = module
        self.wrap_configure_model(module)
        self.wrap_configure_optimizers(module)

    def wrap_configure_model(self, pl_module):
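        """Wrap `pl_module.configure_model` so this callback's module is built with it.

        The wrapped version first runs the original `configure_model`, then calls
        this callback's `configure_model`, stores the result in
        `pl_module.callbacks_modules[self.name]`, and registers the callback's
        metrics in `pl_module.callbacks_metrics[self.name]` via
        `format_metrics_as_dict`.
        """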
        fn = pl_module.configure_model

        def new_configure_model(self, callback=self, fn=fn):
            # Initialize module
            fn()
            module = callback.configure_model(self)
            # Store module in pl_module.callbacks_modules
            logging.info(f"{callback.name}: Storing module in callbacks_modules")
            self.callbacks_modules[callback.name] = module
            logging.info(f"{callback.name}: Setting up metrics")
            assert callback.name not in self.callbacks_metrics
            self.callbacks_metrics[callback.name] = format_metrics_as_dict(
                callback.metrics
            )

        # Bind the new method to the instance
        logging.info(f"{self.name}: We are wrapping up your `configure_optimizers`!")
        pl_module.configure_model = types.MethodType(new_configure_model, pl_module)

    def configure_model(self, pl_module: LightningModule) -> torch.nn.Module:
        """Initialize the module for this callback.

        Subclasses must override this method to create their specific module.

        Args:
            pl_module: The Lightning module being trained.

        Returns:
            The initialized module.
        """
        raise NotImplementedError("Subclasses must implement configure_model")

    def wrap_configure_optimizers(self, pl_module):
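        """Wrap `pl_module.configure_optimizers` to append this callback's optimizer.

        The wrapped version runs the original `configure_optimizers`, records this
        callback's gradient accumulation and clipping settings on the pl_module
        (keyed by `self.name`), then appends the optimizer and scheduler created by
        `setup_optimizer` and `setup_scheduler` to the returned lists.
        """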
        fn = pl_module.configure_optimizers

        def new_configure_optimizers(self, callback=self, fn=fn):
            outputs = fn()
            if outputs is None:
                optimizers = []
                schedulers = []
            else:
                optimizers, schedulers = outputs
            # assert callback.name not in self._optimizer_name_to_index
            assert callback.name not in self._optimizer_frequencies
            # assert callback.name not in self._optimizer_names
            assert callback.name not in self._optimizer_gradient_clip_val
            assert callback.name not in self._optimizer_gradient_clip_algorithm
            assert len(optimizers) not in self._optimizer_index_to_name
            self._optimizer_index_to_name[len(optimizers)] = callback.name
            # self._optimizer_name_to_index[callback.name] = len(self._optimizer_names)
            # self._optimizer_names.append(callback.name)
            self._optimizer_frequencies[callback.name] = (
                callback.accumulate_grad_batches
            )
            self._optimizer_gradient_clip_val[callback.name] = (
                callback.gradient_clip_val
            )
            self._optimizer_gradient_clip_algorithm[callback.name] = (
                callback.gradient_clip_algorithm
            )
            optimizers.append(callback.setup_optimizer(self))
            schedulers.append(callback.setup_scheduler(optimizers[-1], self))
            return optimizers, schedulers

        # Bind the new method to the instance
        logging.info(f"{self.name}: We are wrapping up your `configure_optimizers`!")
        pl_module.configure_optimizers = types.MethodType(
            new_configure_optimizers, pl_module
        )

    def setup_optimizer(self, pl_module: LightningModule) -> torch.optim.Optimizer:
        """Initialize optimizer with default LARS if not specified."""
        if self._optimizer_config is None:
            # Use default LARS optimizer for SSL linear probes
            logging.info(f"{self.name}: No optimizer given, using default LARS")
            return LARS(
                self.module.parameters(),
                lr=0.1,
                clip_lr=True,
                eta=0.02,
                exclude_bias_n_norm=True,
                weight_decay=0,
            )
        # Use explicitly provided optimizer config
        logging.info(f"{self.name}: Use explicitly provided optimizer")
        return create_optimizer(self.module.parameters(), self._optimizer_config)

    def setup_scheduler(
        self, optimizer, pl_module: LightningModule
    ) -> torch.optim.lr_scheduler.LRScheduler:
        """Initialize scheduler with default ConstantLR if not specified."""
        if self._scheduler_config is None:
            # Use default ConstantLR scheduler
            logging.info(f"{self.name}: No scheduler given, using default ConstantLR")
            return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
        logging.info(f"{self.name}: Use explicitly provided scheduler")
        return create_scheduler(optimizer, self._scheduler_config, module=pl_module)

    @property
    def module(self):
        """Access module from pl_module.callbacks_modules.

        This property is only accessible after setup() has been called.
        The module is stored centrally in pl_module.callbacks_modules
        to avoid duplication in checkpoints.
        """
        if self._pl_module is None:
            raise AttributeError(
                f"{self.name}: module not accessible before setup(). "
                "The module is initialized during the setup phase."
            )
        return self._pl_module.callbacks_modules[self.name]

    @property
    def state_key(self) -> str:
        """Unique identifier for this callback's state during checkpointing."""
        return f"{self.__class__.__name__}[name={self.name}]"


class EarlyStopping(torch.nn.Module):
    """Early stopping mechanism with support for metric milestones and patience.

    This module provides flexible early stopping capabilities that can halt
    training based on metric performance. It supports both milestone-based
    stopping (stop if the metric does not reach a target by specific epochs) and
    patience-based stopping (stop if the metric does not improve for N epochs).

    Args:
        mode: Optimization direction - 'min' for metrics to minimize (e.g., loss),
            'max' for metrics to maximize (e.g., accuracy).
        milestones: Dict mapping epoch numbers to target metric values. Training
            stops if targets are not met at the specified epochs.
        metric_name: Name of the metric to monitor if the metric is a dict.
        patience: Number of epochs with no improvement before stopping.

    Example:
        >>> early_stop = EarlyStopping(mode="max", milestones={10: 0.8, 20: 0.9})
        >>> # Stops if accuracy < 0.8 at epoch 10 or < 0.9 at epoch 20
    """

    def __init__(
        self,
        mode: str = "min",
        milestones: Optional[dict[int, float]] = None,
        metric_name: Optional[str] = None,
        patience: int = 10,
    ):
        super().__init__()
        self.mode = mode
        self.milestones = milestones or {}
        self.metric_name = metric_name
        self.patience = patience
        self.register_buffer("history", torch.zeros(patience))

    def should_stop(self, metric, step):
        """Return True if the milestone registered for `step` is not met."""
        if self.metric_name is None:
            assert type(metric) is not dict
        else:
            assert self.metric_name in metric
            metric = metric[self.metric_name]
        if step in self.milestones:
            if self.mode == "min":
                return metric > self.milestones[step]
            elif self.mode == "max":
                return metric < self.milestones[step]
        return False
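
# Example (illustrative): milestone-based check with mode="max". `should_stop`
# returns True when the monitored metric misses the target registered for that
# step, and False otherwise (including when no milestone is set for the step):
#
#   >>> stopper = EarlyStopping(mode="max", milestones={10: 0.8})
#   >>> stopper.should_stop(0.75, step=10)  # below the epoch-10 target
#   True
#   >>> stopper.should_stop(0.85, step=10)
#   False
#   >>> stopper.should_stop(0.5, step=3)  # no milestone registered for epoch 3
#   False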


def format_metrics_as_dict(metrics):
    """Formats various metric input formats into a standardized dictionary structure.

    This utility function handles multiple input formats for metrics and converts
    them into a consistent ModuleDict structure with separate train and validation
    metrics. This standardization simplifies metric handling across callbacks.

    Args:
        metrics: Can be:

            - None: Returns empty train and val dicts
            - Single torchmetrics.Metric: Applied to validation only
            - Dict with 'train' and 'val' keys: Separated accordingly
            - Dict of metrics: All applied to validation
            - List/tuple of metrics: All applied to validation

    Returns:
        ModuleDict with '_train' and '_val' keys, each containing metric ModuleDicts.

    Raises:
        ValueError: If metrics format is invalid or contains non-torchmetric objects.
    """
    # Handle OmegaConf types
    from omegaconf import DictConfig, ListConfig

    if isinstance(metrics, (ListConfig, DictConfig)):
        import omegaconf

        metrics = omegaconf.OmegaConf.to_container(metrics, resolve=True)
    if metrics is None:
        train = {}
        eval = {}
    elif isinstance(metrics, torchmetrics.Metric):
        train = {}
        eval = torch.nn.ModuleDict({metrics.__class__.__name__: metrics})
    elif type(metrics) is dict and set(metrics.keys()) == set(["train", "val"]):
        if type(metrics["train"]) in [list, tuple]:
            train = {}
            for m in metrics["train"]:
                if not isinstance(m, torchmetrics.Metric):
                    raise ValueError(f"metric {m} is not a torchmetric")
                train[m.__class__.__name__] = m
        else:
            train = metrics["train"]
        if type(metrics["val"]) in [list, tuple]:
            eval = {}
            for m in metrics["val"]:
                if not isinstance(m, torchmetrics.Metric):
                    raise ValueError(f"metric {m} is not a torchmetric")
                eval[m.__class__.__name__] = m
        else:
            eval = metrics["val"]
    elif type(metrics) is dict:
        train = {}
        for k, v in metrics.items():
            assert type(k) is str
            assert isinstance(v, torchmetrics.Metric)
        eval = metrics
    elif type(metrics) in [list, tuple]:
        train = {}
        eval = {}
        for m in metrics:
            if not isinstance(m, torchmetrics.Metric):
                raise ValueError(f"metric {m} is not a torchmetric")
            eval[m.__class__.__name__] = m
    else:
        raise ValueError(
            "metrics can only be a torchmetric or a dict/list/tuple of torchmetrics"
        )
    return torch.nn.ModuleDict(
        {
            "_train": torch.nn.ModuleDict(train),
            "_val": torch.nn.ModuleDict(eval),
        }
    )
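
# Example (illustrative; exact metric class names depend on the installed
# torchmetrics version): a single torchmetrics.Metric is placed under '_val',
# keyed by its class name, while '_train' stays empty:
#
#   >>> import torchmetrics
#   >>> formatted = format_metrics_as_dict(
#   ...     torchmetrics.Accuracy(task="multiclass", num_classes=10)
#   ... )
#   >>> sorted(formatted.keys())
#   ['_train', '_val']
#   >>> list(formatted["_val"].keys())
#   ['MulticlassAccuracy']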