Source code for stable_pretraining.callbacks.wd_schedule

import math
from typing import Optional
from loguru import logger
from lightning.pytorch import Callback, Trainer, LightningModule


class WeightDecayUpdater(Callback):
    """PyTorch Lightning Callback to update an optimizer's weight decay per batch.

    - Supports multiple schedules: 'constant', 'linear', 'cosine', 'exponential'
    - Optionally updates only specific optimizer param groups (by index)
    - Infers total steps from the Trainer config (max_steps or max_epochs + dataloader)
    - Checkpointable: state is saved/restored with Trainer checkpoints
    - Extensive Loguru logging

    Args:
        schedule_type (str): One of 'constant', 'linear', 'cosine', 'exponential'.
        start_value (float): Initial weight decay value.
        end_value (float): Final weight decay value (for non-constant schedules).
            Must be positive for the 'exponential' schedule.
        param_group_indices (list[int] or None): List of param group indices to
            update. If None, updates all param groups.
        opt_idx (int or None): Index of the optimizer to update when the module
            uses multiple optimizers. If None, every optimizer is updated.
    """

    def __init__(
        self,
        schedule_type: str = "cosine",
        start_value: float = 0.01,
        end_value: float = 0.0,
        param_group_indices: Optional[list] = None,
        opt_idx: Optional[int] = None,
    ):
        super().__init__()
        self.schedule_type = schedule_type
        self.start_value = start_value
        self.end_value = end_value
        self.param_group_indices = param_group_indices
        self.total_steps = None  # Will be set in on_fit_start
        self.opt_idx = opt_idx
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
        # estimated_stepping_batches counts optimizer steps; multiply by the
        # accumulation factor to express the schedule length in batches.
        self.total_steps = (
            trainer.estimated_stepping_batches * trainer.accumulate_grad_batches
        )
        logger.info(f"[WeightDecayUpdater] Using total_steps={self.total_steps}")
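    # Worked example of the inference above (hypothetical numbers): with 1000
    # batches per epoch, max_epochs=10 and accumulate_grad_batches=2, Lightning
    # estimates 1000 * 10 / 2 = 5000 optimizer steps, so total_steps becomes
    # 5000 * 2 = 10000, i.e. the schedule length measured in batches.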
    def on_before_optimizer_step(
        self, trainer: Trainer, pl_module: LightningModule, optimizer
    ):
        # pl_module.optimizers() returns a single optimizer when only one is
        # configured; normalize to a list so indexing below is safe.
        optis = pl_module.optimizers()
        if not isinstance(optis, (list, tuple)):
            optis = [optis]
        if self.opt_idx is not None and optimizer != optis[self.opt_idx].optimizer:
            return
        # global_step counts steps across all optimizers; divide to get the
        # per-optimizer step index.
        step = trainer.global_step // len(optis)
        accumulate_grad_batches = trainer.accumulate_grad_batches
        if (step + 1) % accumulate_grad_batches != 0:
            logger.debug(
                "[WeightDecayUpdater] Step but accumulating grad, skipping step"
            )
            return
        new_weight_decay = self._compute_weight_decay(step)
        indices = (
            self.param_group_indices
            if self.param_group_indices is not None
            else range(len(optimizer.param_groups))
        )
        for i in indices:
            param_group = optimizer.param_groups[i]
            old_wd = param_group.get("weight_decay", None)
            param_group["weight_decay"] = new_weight_decay
            logger.debug(
                f"[WeightDecayUpdater] Step {step}: param_group {i} "
                f"weight_decay {old_wd} -> {new_weight_decay}"
            )
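    # Sketch of how param_group_indices maps onto optimizer.param_groups
    # (hypothetical configure_optimizers, for illustration only). With the two
    # groups below, param_group_indices=[0] schedules weight decay for the
    # weights only and leaves the bias group untouched:
    #
    #     def configure_optimizers(self):
    #         weights = [p for n, p in self.named_parameters() if "bias" not in n]
    #         biases = [p for n, p in self.named_parameters() if "bias" in n]
    #         return torch.optim.AdamW(
    #             [
    #                 {"params": weights, "weight_decay": 0.01},  # group 0
    #                 {"params": biases, "weight_decay": 0.0},    # group 1
    #             ],
    #             lr=1e-3,
    #         )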
    def _compute_weight_decay(self, step: int) -> float:
        progress = min(step, self.total_steps) / self.total_steps
        if self.schedule_type == "constant":
            return self.start_value
        elif self.schedule_type == "linear":
            return self.start_value + (self.end_value - self.start_value) * progress
        elif self.schedule_type == "cosine":
            return self.end_value + 0.5 * (self.start_value - self.end_value) * (
                1 + math.cos(math.pi * progress)
            )
        elif self.schedule_type == "exponential":
            # Geometric interpolation from start_value to end_value; both must
            # be positive, otherwise the log below is undefined.
            if self.start_value <= 0 or self.end_value <= 0:
                raise ValueError(
                    "'exponential' schedule requires positive start_value and end_value"
                )
            gamma = math.log(self.end_value / self.start_value) / self.total_steps
            return self.start_value * math.exp(gamma * step)
        else:
            logger.error(
                f"[WeightDecayUpdater] Unknown schedule_type: {self.schedule_type}"
            )
            raise ValueError(f"Unknown schedule_type: {self.schedule_type}")
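    # Worked example of the schedules above, assuming start_value=0.01,
    # end_value=0.0 and total_steps=100 (values rounded):
    #   linear: step 0 -> 0.01, step 25 -> 0.0075, step 50 -> 0.005, step 100 -> 0.0
    #   cosine: step 0 -> 0.01, step 25 -> ~0.0085, step 50 -> 0.005, step 100 -> 0.0
    # The cosine schedule decays slowly near both endpoints and fastest in the
    # middle of training.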
    def state_dict(self):
        return {
            "schedule_type": self.schedule_type,
            "start_value": self.start_value,
            "end_value": self.end_value,
            "param_group_indices": self.param_group_indices,
            "total_steps": self.total_steps,
            "opt_idx": self.opt_idx,
        }
    def load_state_dict(self, state_dict):
        self.schedule_type = state_dict.get("schedule_type", self.schedule_type)
        self.start_value = state_dict.get("start_value", self.start_value)
        self.end_value = state_dict.get("end_value", self.end_value)
        self.opt_idx = state_dict.get("opt_idx", self.opt_idx)
        self.param_group_indices = state_dict.get(
            "param_group_indices", self.param_group_indices
        )
        self.total_steps = state_dict.get("total_steps", self.total_steps)
        logger.info(
            f"[WeightDecayUpdater] State restored from checkpoint: {state_dict}"
        )
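# Usage sketch (illustrative, not part of the original module). BoringModel
# from Lightning's demos stands in for a real LightningModule; swap in your
# own module, data and schedule values. Because state_dict/load_state_dict
# are implemented, the schedule configuration is restored automatically when
# resuming from a Trainer checkpoint.
if __name__ == "__main__":
    from lightning.pytorch.demos.boring_classes import BoringModel

    wd_cb = WeightDecayUpdater(
        schedule_type="cosine",
        start_value=0.04,
        end_value=0.0004,
    )
    trainer = Trainer(max_steps=100, callbacks=[wd_cb], logger=False)
    trainer.fit(BoringModel())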