# Source code for stable_ssl.callbacks.utils

import torch
import torchmetrics


class EarlyStopping(torch.nn.Module):
    """Early stopping module that can stop training based on metric milestones.

    A milestone maps a training step to a threshold. When ``should_stop`` is
    called at a step that has a milestone, the metric is compared against the
    threshold: in ``"min"`` mode training should stop if the metric is still
    above it, in ``"max"`` mode if the metric is still below it.

    Parameters
    ----------
    mode : str
        Either ``"min"`` (lower metric is better) or ``"max"`` (higher metric
        is better).
    milestones : dict[int, float], optional
        Mapping from step to the threshold the metric must have reached by
        that step. ``None`` is treated as an empty mapping.
    metric_name : str, optional
        Key used to extract the metric when ``should_stop`` receives a dict.
        When ``None``, ``should_stop`` expects the metric value itself.
    patience : int
        Length of the zero-initialised ``history`` buffer registered on the
        module. ``should_stop`` does not read this buffer.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        mode: str = "min",
        milestones: dict[int, float] = None,
        metric_name: str = None,
        patience: int = 10,
    ):
        super().__init__()
        # An unknown mode would make every milestone check fall through and
        # silently disable early stopping, so reject it up front.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.milestones = milestones or {}
        self.metric_name = metric_name
        self.patience = patience
        # Registered as a buffer so it follows the module across devices and
        # is included in its state dict.
        self.register_buffer("history", torch.zeros(patience))

    def should_stop(self, metric, step):
        """Tell whether training should stop at ``step`` given ``metric``.

        Parameters
        ----------
        metric
            The metric value, or a dict containing it under ``metric_name``
            when ``metric_name`` was provided.
        step : int
            Current step; only steps present in ``milestones`` are checked.

        Returns
        -------
        bool
            Result of comparing the metric with the milestone threshold
            (``metric > threshold`` in ``"min"`` mode, ``metric < threshold``
            in ``"max"`` mode); ``False`` when ``step`` has no milestone.
        """
        # NOTE: these checks are asserts (AssertionError on failure) and are
        # therefore skipped when Python runs with -O.
        if self.metric_name is None:
            assert type(metric) is not dict
        else:
            assert self.metric_name in metric
            metric = metric[self.metric_name]

        if step not in self.milestones:
            return False

        threshold = self.milestones[step]
        if self.mode == "min":
            return metric > threshold
        if self.mode == "max":
            return metric < threshold
        # Only reachable if ``mode`` was reassigned after construction.
        return False
def _metric_sequence_to_dict(metrics):
    """Key each metric of a list/tuple by its class name, validating its type."""
    named = {}
    for m in metrics:
        if not isinstance(m, torchmetrics.Metric):
            raise ValueError(f"metric {m} is no a torchmetric")
        # Two metrics of the same class share a key; the later one wins.
        named[m.__class__.__name__] = m
    return named


def _split_metrics_to_dict(metrics):
    """Convert a list/tuple of metrics to a dict; pass anything else through."""
    if type(metrics) in (list, tuple):
        return _metric_sequence_to_dict(metrics)
    return metrics


def format_metrics_as_dict(metrics):
    """Normalise a metric specification into train/val ``ModuleDict`` containers.

    Accepted forms of ``metrics``:

    * ``None``: no metrics at all.
    * a single ``torchmetrics.Metric``: used for validation, keyed by its
      class name.
    * a dict whose keys are exactly ``"train"`` and ``"val"``: each value is
      either a list/tuple of metrics (keyed by class name) or is passed
      through unchanged.
    * any other dict: must map ``str`` names to ``torchmetrics.Metric``
      instances; used for validation.
    * a list/tuple of metrics: used for validation, keyed by class name.

    Parameters
    ----------
    metrics
        The metric specification, in one of the forms listed above.

    Returns
    -------
    torch.nn.ModuleDict
        A module dict with two entries, ``"_train"`` and ``"_val"``, each a
        ``torch.nn.ModuleDict`` of the corresponding metrics.

    Raises
    ------
    ValueError
        If a list/tuple element is not a ``torchmetrics.Metric``, or if
        ``metrics`` is of an unsupported type.
    AssertionError
        If a plain dict has a non-``str`` key or a non-metric value.
    """
    if metrics is None:
        train, val = {}, {}
    elif isinstance(metrics, torchmetrics.Metric):
        train = {}
        val = {metrics.__class__.__name__: metrics}
    elif type(metrics) is dict and set(metrics.keys()) == {"train", "val"}:
        train = _split_metrics_to_dict(metrics["train"])
        # The non-sequence fallback must read the "val" entry: the key set
        # is exactly {"train", "val"}, so no other key exists here.
        val = _split_metrics_to_dict(metrics["val"])
    elif type(metrics) is dict:
        train = {}
        for k, v in metrics.items():
            assert type(k) is str
            assert isinstance(v, torchmetrics.Metric)
        val = metrics
    elif type(metrics) in (list, tuple):
        train = {}
        val = _metric_sequence_to_dict(metrics)
    else:
        raise ValueError(
            "metrics can only be a torchmetric of list/tuple of torchmetrics"
        )
    return torch.nn.ModuleDict(
        {
            "_train": torch.nn.ModuleDict(train),
            "_val": torch.nn.ModuleDict(val),
        }
    )