# Source code for stable_pretraining.callbacks.earlystop
import lightning as pl
from typing import Union
from loguru import logger as logging
import numpy as np
import torchmetrics
def to_scalar(x):
    """Reduce *x* to a plain Python scalar.

    A ``torchmetrics.Metric`` is first ``compute()``d and then unwrapped with
    ``.item()``; any other object exposing ``.item()`` (e.g. a 0-d tensor) is
    unwrapped directly; everything else is returned unchanged.
    """
    if isinstance(x, torchmetrics.Metric):
        return x.compute().item()
    if hasattr(x, "item"):
        return x.item()
    return x
class EpochMilestones(pl.Callback):
    """Stop training early when monitored metrics miss per-epoch thresholds.

    "Milestones" map specific epochs to required metric values. At each
    milestone epoch, the selected metrics (from ``trainer.callback_metrics``)
    are compared against the required value; if the requirement is not met,
    ``trainer.should_stop`` is set so Lightning ends the run.

    Args:
        milestones (dict[int, float]):
            Maps epoch numbers to required metric values. At each listed
            epoch, the monitored metrics are checked against the value.
        monitor (str | list[str], optional):
            Exact metric name(s) to monitor, as logged in
            ``trainer.callback_metrics``. Takes precedence over ``contains``.
        contains (str | list[str], optional):
            Substring(s); every logged metric whose name contains one of
            them is monitored. Used only when ``monitor`` is not given.
        direction (str, optional):
            One of "max" or "min".
            - "max": training stops if the largest matched value is below
              the milestone value.
            - "min": training stops if the smallest matched value is above
              the milestone value.
            Default is "max".
        after_validation (bool, optional):
            If True (default), the check runs after each validation epoch
            (``on_validation_epoch_end``); otherwise after each training
            epoch (``on_train_epoch_end``).
        strict (bool, optional):
            If True (default), raise ``RuntimeError`` when no monitored
            metric can be found; otherwise only log a warning.

    Raises:
        ValueError: If both ``monitor`` and ``contains`` are None.
        RuntimeError: If ``strict`` and no monitored metric is found.

    Example:
        >>> milestones = {10: 0.2, 20: 0.5}
        >>> callback = EpochMilestones(
        ...     milestones=milestones,
        ...     monitor="eval/accuracy",
        ...     direction="max",
        ...     after_validation=True,
        ... )
        >>> trainer = pl.Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        milestones: dict[int, float],
        monitor: Union[list[str], str] = None,
        contains: str = None,
        direction: str = "max",
        after_validation: bool = True,
        strict: bool = True,
    ):
        if monitor is None and contains is None:
            raise ValueError("`monitor` and `contains` can't both be None")
        super().__init__()
        # Normalize single names to lists so downstream code handles one case.
        if isinstance(monitor, str):
            monitor = [monitor]
        if isinstance(contains, str):
            contains = [contains]
        self.monitor = monitor
        self.contains = contains
        self.strict = strict
        self.milestones = milestones
        self.direction = direction
        self.after_validation = after_validation

    def _missing_metrics(self, metrics, matched_count):
        """Handle the no-metric-found case according to ``self.strict``."""
        msg = f"No metrics found for monitor='{self.monitor}' contains='{self.contains}' in callback_metrics: {list(metrics.keys())}"
        if self.strict:
            raise RuntimeError(msg)
        logging.warning(msg)

    def _check_condition(self, trainer):
        """Check the milestone requirement for the current epoch; set ``trainer.should_stop`` when it is not met."""
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch
        # Select metrics by exact or substring match.
        if self.monitor:
            # Only keep names that are actually present: using ``metrics.get``
            # here would insert None values, making absence undetectable and
            # later crashing np.max/np.min on a None entry.
            matched = {m: metrics[m] for m in self.monitor if m in metrics}
        else:
            matched = {
                k: v
                for contain in self.contains
                for k, v in metrics.items()
                if contain in k
            }
        # Sanity check: verify presence of the monitored metrics.
        if trainer.sanity_checking:
            logging.info("Sanity checking EpochMilestones...")
            logging.info(f"We matched {len(matched)} metrics!")
            if not matched:
                self._missing_metrics(metrics, len(matched))
            logging.info(
                f"Sanity check passed, congrats! We will use {self.milestones}..."
            )
            return
        if epoch not in self.milestones:
            logging.info(f"EpochMilestones: {epoch=} is not in milestones, skipping...")
            return
        logging.info(f"EpochMilestones: {epoch=} is in milestones, checking condition!")
        # Retrieve the logged metric values as plain scalars.
        values = [to_scalar(x) for x in matched.values()]
        if not values:
            # A milestone epoch with nothing to check: same policy as the
            # sanity check (raise when strict, otherwise warn and skip).
            self._missing_metrics(metrics, 0)
            return
        threshold = self.milestones[epoch]
        # Stop training if the best matched value fails the threshold.
        if self.direction == "max":
            final = np.max(values)
            logging.info(f"EpochMilestones: Maximum value is {final}")
            if final < threshold:
                logging.warning(
                    f"EpochMilestones: Value {final} below threshold"
                    f" {threshold}... stopping!"
                )
                trainer.should_stop = True
            else:
                logging.warning(
                    f"EpochMilestones: Value {final} above threshold"
                    f" {threshold}... Yayy!"
                )
        else:
            final = np.min(values)
            logging.info(f"EpochMilestones: Minimum value is {final}")
            if final > threshold:
                logging.warning(
                    f"EpochMilestones: Value {final} above threshold"
                    f" {threshold}... stopping!"
                )
                trainer.should_stop = True
            else:
                logging.warning(
                    f"EpochMilestones: Value {final} below threshold"
                    f" {threshold}... Yayy!"
                )

    def on_train_epoch_end(self, trainer, pl_module):
        # NOTE: the Lightning hook is ``on_train_epoch_end`` — the previous
        # ``on_training_epoch_end`` name was never called by Lightning, so the
        # after_validation=False mode silently did nothing.
        if self.after_validation:
            return
        self._check_condition(trainer)

    # Backward-compatible alias for any code that invoked the old name directly.
    on_training_epoch_end = on_train_epoch_end

    def on_validation_epoch_end(self, trainer, pl_module):
        if not self.after_validation:
            return
        self._check_condition(trainer)