Source code for stable_pretraining.optim.lr_scheduler

import functools
import math

import numpy as np
import torch
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LambdaLR,
    LinearLR,
    MultiStepLR,
    SequentialLR,
    _LRScheduler,
)

# Default parameter factories for common schedulers (both torch and custom)
# These callables receive the calling module (for trainer context) and optimizer
DEFAULT_SCHEDULER_FACTORIES = {
    # torch schedulers
    "CosineAnnealingLR": lambda module, opt: {
        "T_max": getattr(module.trainer, "estimated_stepping_batches", None),
    },
    "OneCycleLR": lambda module, opt: {
        "max_lr": opt.param_groups[0]["lr"],
        "total_steps": getattr(module.trainer, "estimated_stepping_batches", None),
        "pct_start": min(10 / getattr(module.trainer, "max_epochs", 1), 0.01),
    },
    "StepLR": lambda module, opt: {"step_size": 30, "gamma": 0.1},
    "ExponentialLR": lambda module, opt: {"gamma": 0.9},
    "ReduceLROnPlateau": lambda module, opt: {
        "mode": "min",
        "patience": 10,
        "factor": 0.1,
    },
    "LinearLR": lambda module, opt: {},
    "ConstantLR": lambda module, opt: {},
    # custom schedulers (defined below)
    "LinearWarmup": lambda module, opt: {
        "total_steps": getattr(module.trainer, "estimated_stepping_batches", None),
        "start_factor": 0.01,
        "peak_step": max(
            1, int(0.01 * getattr(module.trainer, "estimated_stepping_batches", 1))
        ),
    },
    "LinearWarmupCosineAnnealing": lambda module, opt: {
        "total_steps": getattr(module.trainer, "estimated_stepping_batches", None),
        "start_factor": 0.01,
        "end_lr": 0.0,
        "peak_step": max(
            1, int(0.01 * getattr(module.trainer, "estimated_stepping_batches", 1))
        ),
    },
    "LinearWarmupCyclicAnnealing": lambda module, opt: {
        "total_steps": getattr(module.trainer, "estimated_stepping_batches", None),
        "start_factor": 0.01,
        "peak_step": max(
            1, int(0.1 * getattr(module.trainer, "estimated_stepping_batches", 1))
        ),
    },
    "LinearWarmupThreeStepsAnnealing": lambda module, opt: {
        "total_steps": getattr(module.trainer, "estimated_stepping_batches", None),
        "start_factor": 0.001,
        "gamma": 0.3,
        "peak_step": max(
            1, int(0.05 * getattr(module.trainer, "estimated_stepping_batches", 1))
        ),
    },
    "LinearWarmupCosineAnnealingLR": lambda module, opt: {
        "warmup_steps": max(
            1, int(0.01 * getattr(module.trainer, "estimated_stepping_batches", 1))
        ),
        "max_steps": getattr(module.trainer, "estimated_stepping_batches", None),
        "warmup_start_lr": 0.0,
        "eta_min": 0.0,
    },
}
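
# Note (illustrative, not part of the library API): entries in this dict only
# supply default keyword arguments when `create_scheduler` below receives a
# scheduler name without explicit params; the scheduler itself must still be
# resolvable from torch.optim.lr_scheduler or from this module. A hypothetical
# "MyScheduler" entry would look like:
#
#   DEFAULT_SCHEDULER_FACTORIES["MyScheduler"] = lambda module, opt: {
#       "total_steps": getattr(module.trainer, "estimated_stepping_batches", None),
#   }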


def _resolve_scheduler_callable(name_or_class):
    """Resolve a scheduler by name from torch or this module.

    Accepts a string name, class, or callable. Returns a callable to construct the scheduler.
    """
    if not isinstance(name_or_class, str):
        return name_or_class

    # Try torch.optim.lr_scheduler first
    if hasattr(torch.optim.lr_scheduler, name_or_class):
        return getattr(torch.optim.lr_scheduler, name_or_class)

    # Then try this module (custom functions/classes)
    if name_or_class in globals():
        return globals()[name_or_class]

    raise ValueError(
        f"Scheduler '{name_or_class}' not found in torch.optim.lr_scheduler or stable_pretraining.optim.lr_scheduler."
    )


def _build_default_params(name: str, module, optimizer):
    """Return smart default kwargs for a known scheduler name, or an empty dict."""
    factory = DEFAULT_SCHEDULER_FACTORIES.get(name)
    if factory is None:
        return {}
    params = factory(module, optimizer)
    # Remove None in case trainer context is missing
    return {k: v for k, v in params.items() if v is not None}


def create_scheduler(optimizer, scheduler_config, module=None):
    """Create a learning rate scheduler instance from a flexible config.

    Args:
        optimizer: torch optimizer
        scheduler_config: str | dict | partial | Class | callable
        module: optional calling module for defaults (expects module.trainer)

    Returns:
        Instantiated scheduler
    """
    # functools.partial -> call directly with the optimizer
    if isinstance(scheduler_config, functools.partial):
        return scheduler_config(optimizer)

    # dict -> pop type + params
    if isinstance(scheduler_config, dict):
        cfg = dict(scheduler_config)
        scheduler_type = cfg.pop("type", "CosineAnnealingLR")
        params = cfg
    else:
        scheduler_type = scheduler_config
        params = {}

    scheduler_ctor = _resolve_scheduler_callable(scheduler_type)

    # If no params provided, use smart defaults if known
    if not params:
        name = (
            scheduler_ctor.__name__
            if hasattr(scheduler_ctor, "__name__")
            else str(scheduler_ctor)
        )
        try:
            params = _build_default_params(name, module, optimizer)
        except Exception:
            params = {}

    # Instantiate. Works for both torch classes and our function factories.
    return scheduler_ctor(optimizer, **params)
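
# Usage sketch (illustrative, not part of the module): assumes a Lightning-style
# `module` whose trainer exposes `estimated_stepping_batches` so the smart
# defaults can be filled in; `model` is hypothetical.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#   # string config -> resolved by name, defaults pulled from the trainer context
#   scheduler = create_scheduler(optimizer, "LinearWarmupCosineAnnealing", module=module)
#   # dict config -> explicit parameters, no defaults are injected
#   scheduler = create_scheduler(
#       optimizer,
#       {"type": "LinearWarmupCosineAnnealingLR", "warmup_steps": 100, "max_steps": 1000},
#   )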


class CosineDecayer:
    """Apply cosine decay with multiple cycles for learning rate scheduling.

    This class implements a multi-cycle decay function that can be used as a
    learning rate multiplier (e.g. with ``LambdaLR``). The multiplier combines a
    linear decay toward zero over ``total_steps`` with a sinusoidal cyclic
    modulation.

    Args:
        total_steps (int): Total number of training steps.
        n_cycles (int, optional): Number of cycles in the cosine decay. Defaults to 3.
        gamma (float, optional): Gamma parameter for cycle amplitude. Currently
            stored but not used in the decay computation. Defaults to 0.2.

    Example:
        >>> decayer = CosineDecayer(total_steps=1000, n_cycles=3)
        >>> lr_factor = decayer(step=500)
    """

    def __init__(self, total_steps, n_cycles=3, gamma=0.2):
        self.total_steps = total_steps
        self.n_cycles = n_cycles
        # Stored for API completeness; the current decay formula does not use it.
        self.gamma = gamma

    def __call__(self, step):
        """Compute the learning rate factor for the given step.

        Args:
            step (int): Current training step.

        Returns:
            float: Learning rate multiplier factor.
        """
        alpha = 1 - step / self.total_steps
        cycle = 1 + np.sin(self.n_cycles * 2 * np.pi * step / self.total_steps) / 2
        return alpha * cycle


def LinearWarmup(optimizer, total_steps, start_factor=0.01, peak_step=0.1):
    """Create a linear warmup learning rate scheduler.

    This function creates a linear warmup scheduler that gradually increases the
    learning rate from a small value to the full learning rate over a specified
    number of steps.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor. Defaults to 0.01.
        peak_step (float, optional): Step at which warmup peaks. Values below 1 are
            interpreted as a fraction of total_steps, otherwise as an absolute step
            count. Defaults to 0.1.

    Returns:
        torch.optim.lr_scheduler.LinearLR: Linear warmup scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmup(optimizer, total_steps=1000, start_factor=0.01)
    """
    if peak_step < 1:
        peak_step = int(peak_step * total_steps)
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    return warmup


def LinearWarmupCosineAnnealing(
    optimizer, total_steps, start_factor=0.01, end_lr=0.0, peak_step=0.01
):
    """Combine linear warmup with cosine annealing decay.

    This function creates a scheduler that first linearly warms up the learning
    rate, then applies cosine annealing decay. This is commonly used in
    self-supervised learning to achieve better convergence.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor for warmup.
            Defaults to 0.01.
        end_lr (float, optional): Final learning rate after annealing. Defaults to 0.0.
        peak_step (float, optional): Step at which warmup ends. Values below 1 are
            interpreted as a fraction of total_steps, otherwise as an absolute step
            count. Defaults to 0.01.

    Returns:
        torch.optim.lr_scheduler.SequentialLR: Combined warmup and annealing scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupCosineAnnealing(optimizer, total_steps=1000)
    """
    if peak_step < 1:
        peak_step = int(peak_step * total_steps)
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    anneal = CosineAnnealingLR(optimizer, T_max=total_steps - peak_step, eta_min=end_lr)
    scheduler = SequentialLR(
        optimizer,
        [warmup, anneal],
        milestones=[peak_step],
    )
    return scheduler


def LinearWarmupCyclicAnnealing(
    optimizer, total_steps, start_factor=0.01, peak_step=0.1
):
    """Combine linear warmup with cyclic cosine annealing.

    This function creates a scheduler that combines linear warmup with cyclic
    cosine annealing. The cyclic annealing provides multiple learning rate
    cycles, which can help escape local minima during training.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor for warmup.
            Defaults to 0.01.
        peak_step (float, optional): Step at which warmup ends. Values below 1 are
            interpreted as a fraction of total_steps, otherwise as an absolute step
            count. Defaults to 0.1.

    Returns:
        torch.optim.lr_scheduler.SequentialLR: Combined warmup and cyclic annealing scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupCyclicAnnealing(optimizer, total_steps=1000)
    """
    if peak_step < 1:
        peak_step = int(peak_step * total_steps)
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    decay = LambdaLR(optimizer, CosineDecayer(total_steps - peak_step))
    scheduler = SequentialLR(
        optimizer,
        [warmup, decay],
        milestones=[peak_step],
    )
    return scheduler


def LinearWarmupThreeStepsAnnealing(
    optimizer, total_steps, start_factor=0.001, gamma=0.3, peak_step=0.05
):
    """Combine linear warmup with a three-step learning rate annealing.

    This function creates a scheduler that combines linear warmup with a
    three-step annealing schedule. The annealing reduces the learning rate at
    three predefined milestones, which can help with fine-tuning and convergence.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        total_steps (int): Total number of training steps.
        start_factor (float, optional): Initial learning rate factor for warmup.
            Defaults to 0.001.
        gamma (float, optional): Multiplicative factor for learning rate reduction.
            Defaults to 0.3.
        peak_step (float, optional): Step at which warmup ends. Values below 1 are
            interpreted as a fraction of total_steps, otherwise as an absolute step
            count. Defaults to 0.05.

    Returns:
        torch.optim.lr_scheduler.SequentialLR: Combined warmup and three-step
        annealing scheduler.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupThreeStepsAnnealing(optimizer, total_steps=1000)
    """
    if peak_step < 1:
        peak_step = int(peak_step * total_steps)
    warmup = LinearLR(optimizer, start_factor, total_iters=peak_step)
    anneal = MultiStepLR(
        optimizer,
        # MultiStepLR milestones are compared against integer step counts, so
        # cast the fractional positions to int to ensure they actually trigger.
        milestones=[
            int((total_steps - peak_step) * 0.4),
            int((total_steps - peak_step) * 0.6),
            int((total_steps - peak_step) * 0.8),
        ],
        gamma=gamma,
    )
    scheduler = SequentialLR(
        optimizer,
        [warmup, anneal],
        milestones=[peak_step],
    )
    return scheduler


class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Learning rate scheduler with linear warmup followed by cosine annealing.

    This scheduler implements a custom learning rate schedule that combines
    linear warmup with cosine annealing. It provides more control over the
    warmup and annealing phases compared to the factory function approach.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to schedule.
        warmup_steps (int): Number of steps for linear warmup.
        max_steps (int): Total number of training steps.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
            Defaults to 0.0.
        eta_min (float, optional): Minimum learning rate after annealing. Defaults to 0.0.
        last_epoch (int, optional): The index of last epoch. Defaults to -1.

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> scheduler = LinearWarmupCosineAnnealingLR(
        ...     optimizer, warmup_steps=100, max_steps=1000
        ... )
    """

    def __init__(
        self,
        optimizer,
        warmup_steps,
        max_steps,
        warmup_start_lr=0.0,
        eta_min=0.0,
        last_epoch=-1,
    ):
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Compute the learning rate for the current epoch.

        Returns:
            list: List of learning rates for each parameter group.
        """
        if self.last_epoch < self.warmup_steps:
            return [
                (
                    self.warmup_start_lr
                    + (base_lr - self.warmup_start_lr)
                    * self.last_epoch
                    / self.warmup_steps
                )
                for base_lr in self.base_lrs
            ]
        else:
            return [
                self.eta_min
                + (base_lr - self.eta_min)
                * (
                    1
                    + math.cos(
                        math.pi
                        * (self.last_epoch - self.warmup_steps)
                        / (self.max_steps - self.warmup_steps)
                    )
                )
                / 2
                for base_lr in self.base_lrs
            ]
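
# Usage sketch (illustrative): as used by the defaults above, `warmup_steps` and
# `max_steps` count optimization steps (batches), so `scheduler.step()` is called
# once per batch; `dataloader` and the loop below are assumptions.
#
#   scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_steps=100, max_steps=1000)
#   for batch in dataloader:
#       ...  # forward/backward
#       optimizer.step()
#       scheduler.step()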