# Source code for stable_pretraining.optim.lr_scheduler

import math

import numpy as np
import torch
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LambdaLR,
    LinearLR,
    MultiStepLR,
    SequentialLR,
    _LRScheduler,
)
import inspect
from functools import partial
from loguru import logger as logging
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from typing import Any, Union


def _trainer_attr(module, name, default=None):
    """Return ``module.trainer.<name>``, or ``default`` when it is unavailable.

    ``default`` is returned when ``module`` is ``None`` or exposes no usable
    ``trainer``, when the trainer has no attribute ``name``, or when that
    attribute is ``None``. This lets the factories below degrade gracefully
    when trainer context is missing instead of raising.
    """
    try:
        trainer = module.trainer
    except (AttributeError, RuntimeError):
        # AttributeError covers ``module=None``; RuntimeError is presumably what a
        # Lightning module raises when it is not attached to a Trainer.
        # NOTE(review): confirm against the Lightning version in use.
        return default
    value = getattr(trainer, name, None)
    return default if value is None else value


def _warmup_steps(module, fraction):
    """Return ``fraction`` of the estimated total steps, never less than 1."""
    total = _trainer_attr(module, "estimated_stepping_batches", 1)
    return max(1, int(fraction * total))


# Default parameter factories for common schedulers (both torch and custom).
# Each callable receives the calling module (for trainer context; may be None)
# and the optimizer, and returns keyword arguments for the scheduler. Values
# that cannot be resolved come back as None and are dropped by
# ``_build_default_params``.
DEFAULT_SCHEDULER_FACTORIES = {
    # torch schedulers
    "CosineAnnealingLR": lambda module, opt: {
        "T_max": _trainer_attr(module, "estimated_stepping_batches"),
    },
    "OneCycleLR": lambda module, opt: {
        "max_lr": opt.param_groups[0]["lr"],
        "total_steps": _trainer_attr(module, "estimated_stepping_batches"),
        # max_epochs is clamped to >= 1 so that 0 / negative values cannot
        # divide by zero or produce a negative pct_start.
        "pct_start": min(
            10 / max(_trainer_attr(module, "max_epochs", 1), 1),
            0.01,
        ),
    },
    "StepLR": lambda module, opt: {"step_size": 30, "gamma": 0.1},
    "ExponentialLR": lambda module, opt: {"gamma": 0.9},
    "ReduceLROnPlateau": lambda module, opt: {
        "mode": "min",
        "patience": 10,
        "factor": 0.1,
    },
    "LinearLR": lambda module, opt: {},
    "ConstantLR": lambda module, opt: {},
    # custom schedulers (defined below)
    "LinearWarmup": lambda module, opt: {
        "total_steps": _trainer_attr(module, "estimated_stepping_batches"),
        "start_factor": 0.01,
        "peak_step": _warmup_steps(module, 0.01),
    },
    "LinearWarmupCosineAnnealing": lambda module, opt: {
        "total_steps": _trainer_attr(module, "estimated_stepping_batches"),
        "start_factor": 0.01,
        "end_lr": 0.0,
        "peak_step": _warmup_steps(module, 0.01),
    },
    "LinearWarmupCyclicAnnealing": lambda module, opt: {
        "total_steps": _trainer_attr(module, "estimated_stepping_batches"),
        "start_factor": 0.01,
        "peak_step": _warmup_steps(module, 0.1),
    },
    "LinearWarmupThreeStepsAnnealing": lambda module, opt: {
        "total_steps": _trainer_attr(module, "estimated_stepping_batches"),
        "start_factor": 0.001,
        "gamma": 0.3,
        "peak_step": _warmup_steps(module, 0.05),
    },
    "LinearWarmupCosineAnnealingLR": lambda module, opt: {
        "warmup_steps": _warmup_steps(module, 0.01),
        "max_steps": _trainer_attr(module, "estimated_stepping_batches"),
        "warmup_start_lr": 0.0,
        "eta_min": 0.0,
    },
}


def _build_default_params(name: str, module, optimizer):
    """Look up and evaluate the default-parameter factory registered for ``name``.

    Args:
        name: Scheduler name, used as a key into ``DEFAULT_SCHEDULER_FACTORIES``.
        module: Forwarded to the factory (it may read trainer context from it).
        optimizer: Forwarded to the factory.

    Returns:
        dict: The keyword arguments produced by the factory, with every ``None``
        value removed, or an empty dict when no factory is registered for ``name``.
    """
    factory = DEFAULT_SCHEDULER_FACTORIES.get(name)
    if factory is not None:
        defaults = factory(module, optimizer)
        # A None value means the factory could not resolve it (e.g. no trainer
        # context); omit it rather than passing None to the scheduler.
        return {key: value for key, value in defaults.items() if value is not None}
    return {}


def create_scheduler(
    optimizer: torch.optim.Optimizer,
    scheduler_config: Union[str, dict, partial, type],
    module: Any = None,
) -> torch.optim.lr_scheduler.LRScheduler:
    """Create a learning rate scheduler with flexible configuration.

    This function provides a unified way to create schedulers from various configuration formats,
    used by both Module and OnlineProbe for consistency.

    Args:
        optimizer: The optimizer to attach the scheduler to
        scheduler_config: Can be:
            - Hydra config: any object with a ``_target_`` attribute, or a plain dict
              with a ``"_target_"`` key; instantiated with ``optimizer=optimizer``
            - str: Name of scheduler (e.g., "CosineAnnealingLR"), looked up in
              ``torch.optim.lr_scheduler`` first and then in this module
            - dict / DictConfig: {"type": "CosineAnnealingLR", "T_max": 1000, ...};
              ``type`` defaults to "CosineAnnealingLR"
            - partial: Pre-configured scheduler (e.g., partial(CosineAnnealingLR, T_max=1000)),
              called with the optimizer only
            - callable with one parameter (called as ``fn(optimizer)``) or two
              parameters (called as ``fn(optimizer, module)``)
            - class: Direct scheduler class whose constructor takes more than two
              parameters (e.g. ``torch.optim.lr_scheduler.StepLR``); built with
              smart defaults
        module: Optional module instance for accessing trainer properties (for smart defaults)

    Returns:
        Configured scheduler instance

    Raises:
        NotImplementedError: If a callable that is not handled as a scheduler class
            declares neither one nor two parameters.
        ValueError: If a dict's ``type`` is not a string, or if the scheduler name
            is not found.
        TypeError: If ``scheduler_config`` has an unsupported type.

    Note:
        Smart defaults (``DEFAULT_SCHEDULER_FACTORIES``) are applied only when no
        explicit parameters were given. If building them fails, a warning is
        logged and the scheduler is constructed with the optimizer alone.

    Examples:
        >>> # Simple string (uses smart defaults)
        >>> scheduler = create_scheduler(opt, "CosineAnnealingLR")

        >>> # With custom parameters
        >>> scheduler = create_scheduler(
        ...     opt, {"type": "StepLR", "step_size": 30, "gamma": 0.1}
        ... )

        >>> # Using partial for full control
        >>> from functools import partial
        >>> scheduler = create_scheduler(
        ...     opt, partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.95)
        ... )
    """
    # Hydra config objects (attribute-style, or a plain dict carrying `_target_`)
    # are delegated to Hydra, which injects the optimizer.
    if hasattr(scheduler_config, "_target_") or (
        isinstance(scheduler_config, dict) and "_target_" in scheduler_config
    ):
        logging.info("\tUser provided a Hydra object, instantiating with optimizer!!")
        return instantiate(scheduler_config, optimizer=optimizer, _convert_="object")

    # A functools.partial is fully configured except for the optimizer.
    if isinstance(scheduler_config, partial):
        logging.info("\tUser provided a partial function, calling with optimizer!!")
        return scheduler_config(optimizer)

    fn = None
    params = {}
    if callable(scheduler_config):
        # Dispatch on the number of declared parameters.
        num_args = len(inspect.signature(scheduler_config).parameters)
        if num_args == 1:
            logging.info(
                "\tUser provided a callable with one arg, calling with optimizer!!"
            )
            return scheduler_config(optimizer)
        if num_args == 2:
            logging.info(
                "\tUser provided a callable with two args, calling with optimizer, module!!"
            )
            return scheduler_config(optimizer, module)
        if not (num_args > 2 and inspect.isclass(scheduler_config)):
            raise NotImplementedError("Not more than 2 args in your lambda scheduler")
        # A scheduler class such as torch's StepLR: its constructor takes the
        # optimizer plus hyper-parameters, so build it with smart defaults below.
        logging.info("\tUser provided a scheduler class")
        fn = scheduler_config
    elif isinstance(scheduler_config, (dict, DictConfig)):
        logging.info("\tUser provided a dict")
        # Convert DictConfig to a plain dict (resolving interpolations) so it can
        # be popped from and splatted as keyword arguments.
        if isinstance(scheduler_config, DictConfig):
            cfg = OmegaConf.to_container(scheduler_config, resolve=True)
        else:
            cfg = dict(scheduler_config)
        scheduler_type = cfg.pop("type", "CosineAnnealingLR")
        if not isinstance(scheduler_type, str):
            raise ValueError(
                "When using a dict specification for scheduler, "
                "the value of `type` must be a string! got "
                f"{scheduler_type}"
            )
        params = cfg
    elif isinstance(scheduler_config, str):
        logging.info("\tUser provided a str (name)")
        scheduler_type = scheduler_config
    else:
        raise TypeError(
            "scheduler_config must be a Hydra config, partial, callable, class, "
            f"dict/DictConfig or str; got {type(scheduler_config).__name__}"
        )

    if fn is None:
        # Resolve the name: torch's schedulers take precedence over ours.
        if hasattr(torch.optim.lr_scheduler, scheduler_type):
            fn = getattr(torch.optim.lr_scheduler, scheduler_type)
        elif scheduler_type in globals():
            fn = globals()[scheduler_type]
        else:
            raise ValueError(
                f"Scheduler '{scheduler_type}' not found in torch.optim.lr_scheduler or stable_pretraining.optim.lr_scheduler."
            )

    # If no params provided, use smart defaults if known
    if not params:
        name = fn.__name__ if hasattr(fn, "__name__") else str(fn)
        try:
            params = _build_default_params(name, module, optimizer)
        except Exception as err:
            # Defaults are best-effort: fall back to none, but report why
            # instead of failing silently.
            logging.warning(
                "\tCould not build default parameters for {}: {!r}", name, err
            )
            params = {}

    # Instantiate. Works for both torch classes and our function factories.
    return fn(optimizer, **params)


class CosineDecayer:
    """Apply cosine decay with multiple cycles for learning rate scheduling.

    Callable returning a learning-rate multiplier (suitable as the ``lr_lambda``
    of :class:`torch.optim.lr_scheduler.LambdaLR`). The multiplier is a linear
    decay ``1 - step / total_steps`` modulated by a sinusoid::

        (1 - step / total_steps) * (1 + sin(2 * pi * n_cycles * step / total_steps) / 2)

    ``step`` is clamped to ``total_steps``, so the multiplier lies in
    ``[0, 1.5]`` and stays at 0 once ``total_steps`` is reached (it never turns
    negative).

    Args:
        total_steps (int): Total number of decay steps. A non-positive value
            yields a multiplier of 0 for every step.
        n_cycles (int, optional): Number of sinusoidal cycles over the decay.
            Defaults to 3.
        gamma (float, optional): Stored as ``self.gamma``; the formula above does
            not currently use it (the cycle amplitude is fixed at 0.5).
            Defaults to 0.2.

    Example:
        >>> decayer = CosineDecayer(total_steps=1000, n_cycles=3)
        >>> lr_factor = decayer(step=500)
    """

    def __init__(self, total_steps, n_cycles=3, gamma=0.2):
        self.total_steps = total_steps
        self.n_cycles = n_cycles
        # Previously accepted but silently discarded; kept so it is inspectable.
        self.gamma = gamma

    def __call__(self, step):
        """Compute the learning rate factor for the given step.

        Args:
            step (int): Current training step.

        Returns:
            float: Learning rate multiplier factor.
        """
        if self.total_steps <= 0:
            # Empty decay phase: the schedule is already over.
            return 0.0
        # Past the end, hold at the final value instead of going negative.
        step = min(step, self.total_steps)
        alpha = 1 - step / self.total_steps
        cycle = 1 + np.sin(self.n_cycles * 2 * np.pi * step / self.total_steps) / 2
        return alpha * cycle


def LinearWarmup(optimizer, total_steps, start_factor=0.01, peak_step=0.1):
    """Create a linear warmup learning rate scheduler.

    The learning rate ramps linearly from ``start_factor * lr`` to the full
    learning rate over ``peak_step`` steps.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor. Defaults to 0.01.
        peak_step (float or int, optional): Warmup length. A value below 1 is a
            fraction of ``total_steps`` (rounded down, but never below 1 step);
            otherwise it is an absolute number of steps. Defaults to 0.1.

    Returns:
        torch.optim.lr_scheduler.LinearLR: Linear warmup scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmup(optimizer, total_steps=1000, start_factor=0.01)
    """
    if peak_step < 1:
        # With total_iters=0, LinearLR never ramps up and the LR stays at
        # start_factor * lr, so always warm up over at least one step.
        peak_step = max(1, int(peak_step * total_steps))
    return LinearLR(optimizer, start_factor, total_iters=peak_step)


def LinearWarmupCosineAnnealing(
    optimizer, total_steps, start_factor=0.01, end_lr=0.0, peak_step=0.01
):
    """Combine linear warmup with cosine annealing decay.

    The learning rate first ramps linearly from ``start_factor * lr`` to ``lr``
    over ``peak_step`` steps. It then follows a cosine curve down to ``end_lr``
    over the remaining ``total_steps - peak_step`` steps.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor for warmup. Defaults to 0.01.
        end_lr (float, optional): Final learning rate after annealing. Defaults to 0.0.
        peak_step (float or int, optional): Warmup length. A value below 1 is a
            fraction of ``total_steps`` (rounded down, but never below 1 step);
            otherwise it is an absolute number of steps. Defaults to 0.01.

    Returns:
        torch.optim.lr_scheduler.SequentialLR: Combined warmup and annealing scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupCosineAnnealing(optimizer, total_steps=1000)
    """
    if peak_step < 1:
        # A zero-length warmup would leave the LR stuck at start_factor * lr.
        peak_step = max(1, int(peak_step * total_steps))
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    # CosineAnnealingLR divides by T_max, so keep the annealing phase non-empty.
    anneal = CosineAnnealingLR(
        optimizer, T_max=max(1, total_steps - peak_step), eta_min=end_lr
    )
    return SequentialLR(optimizer, [warmup, anneal], milestones=[peak_step])


def LinearWarmupCyclicAnnealing(
    optimizer, total_steps, start_factor=0.01, peak_step=0.1
):
    """Combine linear warmup with cyclic cosine annealing.

    The learning rate first ramps linearly from ``start_factor * lr`` to ``lr``
    over ``peak_step`` steps. It is then scaled by a ``CosineDecayer`` over the
    remaining ``total_steps - peak_step`` steps, which gives several learning
    rate cycles during the decay phase.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor for warmup. Defaults to 0.01.
        peak_step (float or int, optional): Warmup length. A value below 1 is a
            fraction of ``total_steps`` (rounded down, but never below 1 step);
            otherwise it is an absolute number of steps. Defaults to 0.1.

    Returns:
        torch.optim.lr_scheduler.SequentialLR: Combined warmup and cyclic annealing scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupCyclicAnnealing(optimizer, total_steps=1000)
    """
    if peak_step < 1:
        # A zero-length warmup would leave the LR stuck at start_factor * lr.
        peak_step = max(1, int(peak_step * total_steps))
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    # CosineDecayer divides by its length, so keep the decay phase non-empty.
    decay = LambdaLR(optimizer, CosineDecayer(max(1, total_steps - peak_step)))
    return SequentialLR(optimizer, [warmup, decay], milestones=[peak_step])


def LinearWarmupThreeStepsAnnealing(
    optimizer, total_steps, start_factor=0.001, gamma=0.3, peak_step=0.05
):
    """Combine linear warmup with a three-step learning rate annealing.

    The learning rate first ramps linearly from ``start_factor * lr`` to ``lr``
    over ``peak_step`` steps. It is then multiplied by ``gamma`` at 40%, 60% and
    80% of the remaining ``total_steps - peak_step`` steps. These milestones are
    rounded to the nearest integer step.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor for warmup. Defaults to 0.001.
        gamma (float, optional): Multiplicative factor for learning rate reduction. Defaults to 0.3.
        peak_step (float or int, optional): Warmup length. A value below 1 is a
            fraction of ``total_steps`` (rounded down, but never below 1 step);
            otherwise it is an absolute number of steps. Defaults to 0.05.

    Returns:
        torch.optim.lr_scheduler.SequentialLR: Combined warmup and three-step annealing scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupThreeStepsAnnealing(optimizer, total_steps=1000)
    """
    if peak_step < 1:
        # A zero-length warmup would leave the LR stuck at start_factor * lr.
        peak_step = max(1, int(peak_step * total_steps))
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    decay_steps = total_steps - peak_step
    # MultiStepLR matches milestones against its integer step counter, so a
    # fractional milestone (e.g. 0.4 * 999 = 399.6) would never fire.
    anneal = MultiStepLR(
        optimizer,
        milestones=[round(decay_steps * fraction) for fraction in (0.4, 0.6, 0.8)],
        gamma=gamma,
    )
    return SequentialLR(optimizer, [warmup, anneal], milestones=[peak_step])


class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Learning rate scheduler with linear warmup followed by cosine annealing.

    For step ``t`` (``self.last_epoch``) and each parameter group's ``base_lr``:

    * ``t < warmup_steps``: the LR ramps linearly from ``warmup_start_lr`` towards
      ``base_lr``, i.e. ``warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_steps``.
    * otherwise: the LR is
      ``eta_min + (base_lr - eta_min) * (1 + cos(pi * e / n)) / 2``, where
      ``n = max(1, max_steps - warmup_steps)`` and ``e = min(t - warmup_steps, n)``.
      The LR therefore reaches ``eta_min`` at ``max_steps`` and stays there.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        warmup_steps (int): Number of steps for linear warmup.
        max_steps (int): Total number of training steps.
        warmup_start_lr (float, optional): Starting learning rate for warmup. Defaults to 0.0.
        eta_min (float, optional): Minimum learning rate after annealing. Defaults to 0.0.
        last_epoch (int, optional): The index of last epoch. Defaults to -1.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupCosineAnnealingLR(
        ...     optimizer, warmup_steps=100, max_steps=1000
        ... )
    """

    def __init__(
        self,
        optimizer,
        warmup_steps,
        max_steps,
        warmup_start_lr=0.0,
        eta_min=0.0,
        last_epoch=-1,
    ):
        # Attributes must be set before the base __init__, which already calls get_lr().
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Compute the learning rate for the current epoch.

        Returns:
            list: List of learning rates for each parameter group.
        """
        if self.last_epoch < self.warmup_steps:
            return [
                (
                    self.warmup_start_lr
                    + (base_lr - self.warmup_start_lr)
                    * self.last_epoch
                    / self.warmup_steps
                )
                for base_lr in self.base_lrs
            ]
        # Guard against max_steps == warmup_steps (division by zero), and clamp
        # progress so the cosine does not wrap back up after max_steps.
        anneal_steps = max(1, self.max_steps - self.warmup_steps)
        elapsed = min(self.last_epoch - self.warmup_steps, anneal_steps)
        return [
            self.eta_min
            + (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * elapsed / anneal_steps))
            / 2
            for base_lr in self.base_lrs
        ]