Source code for stable_pretraining.callbacks.writer

from pathlib import Path
from typing import Union

import torch
from lightning.pytorch import Callback, LightningModule
from loguru import logger as logging


class OnlineWriter(Callback):
    """Writes specified batch data to disk during training and validation.

    This callback enables selective saving of batch data (e.g., features,
    predictions, embeddings) to disk at specified intervals during training.
    It is useful for debugging, visualization, and analysis of model behavior
    during training.

    Features:
        - Flexible saving schedule (every k epochs, last epoch, sanity check)
        - Support for distributed training with optional all_gather
        - Automatic directory creation
        - Configurable for different phases (train, validation, test, predict)

    Args:
        names: Name(s) of the batch keys to save. Can be a string or a list of strings.
        path: Directory path where files will be saved.
        during: Phase(s) during which to save ('train', 'validation', 'test',
            'predict', or a list of these).
        every_k_epochs: Save every k epochs. -1 means every epoch.
        save_last_epoch: Whether to save on the last training epoch.
        save_sanity_check: Whether to save during the sanity-check phase.
        all_gather: Whether to gather data across all distributed processes.

    Files are saved with the naming pattern
    {phase}_epoch={epoch}_batch={batch}.pt, with a _device={rank} suffix
    when all_gather is disabled.
    """

    def __init__(
        self,
        names: Union[str, list[str]],
        path: Union[str, Path],
        during: Union[str, list[str]],
        every_k_epochs: int = -1,
        save_last_epoch: bool = False,
        save_sanity_check: bool = False,
        all_gather: bool = True,
    ) -> None:
        super().__init__()
        logging.info("Setting up OnlineWriter callback")
        logging.info(f"\t- {names=}")
        logging.info(f"\t- {path=}")
        logging.info(f"\t- {during=}")
        logging.info(f"\t- {every_k_epochs=}")
        logging.info(f"\t- {save_last_epoch=}")
        logging.info(f"\t- {save_sanity_check=}")
        logging.info(f"\t- {all_gather=}")
        path = Path(path)
        if isinstance(names, str):
            names = [names]
        if isinstance(during, str):
            during = [during]
        self.names = names
        self.path = path
        self.during = during
        self.every_k_epochs = every_k_epochs
        self.save_last_epoch = save_last_epoch
        self.save_sanity_check = save_sanity_check
        self.is_sanity_check = False
        self.all_gather = all_gather
        if not path.is_dir():
            logging.warning(f"{path=} does not exist, creating it!")
            path.mkdir(parents=True, exist_ok=False)
    def on_sanity_check_start(self, trainer, pl_module):
        self.is_sanity_check = True
        if not self.save_sanity_check:
            logging.warning("OnlineWriter: skipping sanity check writing")

    def on_sanity_check_end(self, trainer, pl_module):
        self.is_sanity_check = False

    def is_writing_epoch(self, pl_module):
        current_epoch = pl_module.current_epoch
        max_epochs = pl_module.trainer.max_epochs
        # -1 means write on every epoch.
        save_every_epoch = self.every_k_epochs == -1
        save_k_epoch = (
            current_epoch % self.every_k_epochs == 0
            if self.every_k_epochs != 0
            else False
        )
        is_last_epoch = current_epoch == max_epochs - 1
        save_last_epoch = self.save_last_epoch and is_last_epoch
        return any([save_every_epoch, save_k_epoch, save_last_epoch])

    def write_at_phase(self, pl_module, phase_name, outputs, batch_idx):
        if self.is_sanity_check and not self.save_sanity_check:
            return
        if not self.is_writing_epoch(pl_module) or phase_name not in self.during:
            return
        file_info = {
            "epoch": pl_module.current_epoch,
            "batch": batch_idx,
        }
        if not self.all_gather:
            # Each rank writes its own file, so include the device index in the name.
            file_info["device"] = pl_module.local_rank
        file_info = "_".join(f"{k}={v}" for k, v in file_info.items())
        filename = f"{phase_name}_{file_info}.pt"
        self.dump(pl_module, outputs, filename)
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.write_at_phase(pl_module, "train", outputs, batch_idx)

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.write_at_phase(pl_module, "predict", outputs, batch_idx)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.write_at_phase(pl_module, "test", outputs, batch_idx)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.write_at_phase(pl_module, "validation", outputs, batch_idx)
    def dump(self, pl_module: LightningModule, outputs: dict, filename: str):
        to_save = {}
        for name in self.names:
            if name not in outputs:
                msg = (
                    f"Asking to write {name} but not present "
                    f"in current batch {list(outputs.keys())}"
                )
                logging.error(msg)
                raise ValueError(msg)
            data = outputs[name]
            if self.all_gather:
                to_save[name] = pl_module.all_gather(data).cpu()
            else:
                to_save[name] = data.cpu()
        # With all_gather, only local rank 0 writes; otherwise every rank writes
        # its own per-device file.
        if not self.all_gather or pl_module.local_rank == 0:
            torch.save(to_save, self.path / filename)
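The snippet below is an illustrative usage sketch rather than part of the module: it assumes a LightningModule whose validation_step returns a dict containing an "embedding" tensor, and the key, path, and hyperparameters are placeholders chosen for the example.

from lightning.pytorch import Trainer

from stable_pretraining.callbacks.writer import OnlineWriter

writer = OnlineWriter(
    names="embedding",          # key(s) expected in the step's output dict
    path="outputs/embeddings",  # created automatically if it does not exist
    during="validation",        # phases: "train", "validation", "test", "predict"
    every_k_epochs=5,           # write every 5th epoch (-1 writes every epoch)
    save_last_epoch=True,
    all_gather=True,            # gather across ranks; local rank 0 writes the file
)

trainer = Trainer(max_epochs=20, callbacks=[writer])
trainer.fit(model, datamodule=datamodule)  # model/datamodule defined elsewhere

# Each written file is a dict mapping the requested names to tensors, e.g.:
#   saved = torch.load("outputs/embeddings/validation_epoch=19_batch=0.pt")
#   saved["embedding"]

With all_gather=True the tensors are gathered across processes before saving (adding a leading world-size dimension in distributed runs); with all_gather=False each rank writes its own file with a device suffix in the name.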