OnlineWriter
- class stable_pretraining.callbacks.OnlineWriter(names: str, path: str | Path, during: str | list[str], every_k_epochs: int = -1, save_last_epoch: bool = False, save_sanity_check: bool = False, all_gather: bool = True)
Bases: Callback
Writes specified batch data to disk during training and validation.
This callback selectively saves batch data (e.g., features, predictions, embeddings) to disk at specified intervals. It is useful for debugging, visualizing, and analyzing model behavior during training.
Features:
- Flexible saving schedule (every k epochs, last epoch, sanity check)
- Support for distributed training with optional all_gather
- Automatic directory creation
- Configurable for different training phases (train, val, test)
- Parameters:
names – Name(s) of the batch keys to save. Can be a string or a list of strings.
path – Directory path where files will be saved.
during – Training phase(s) in which to save: ‘train’, ‘val’, ‘test’, or a list of these.
every_k_epochs – Save every k epochs. -1 means every epoch.
save_last_epoch – Whether to save on the last training epoch.
save_sanity_check – Whether to save during the sanity check phase.
all_gather – Whether to gather data across all distributed processes.
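The scheduling parameters above can be read in more than one way, so the following is only a sketch of one plausible interpretation, not the library's implementation. The helper `should_save` and its `max_epochs` argument are hypothetical names introduced for illustration. The sketch assumes zero-based epoch indices, that "every k epochs" means epochs divisible by k, and that `save_last_epoch` forces a save on the final epoch regardless of k.

```python
def should_save(
    epoch: int,
    max_epochs: int,
    every_k_epochs: int = -1,
    save_last_epoch: bool = False,
) -> bool:
    """Hypothetical helper: one reading of the documented saving schedule."""
    # Assumption: epochs are zero-based, so the last one is max_epochs - 1.
    if save_last_epoch and epoch == max_epochs - 1:
        return True
    # Documented: -1 means every epoch.
    if every_k_epochs == -1:
        return True
    # Only -1 and positive values are covered by this sketch.
    if every_k_epochs <= 0:
        raise ValueError("every_k_epochs must be -1 or a positive integer")
    # Assumption: "every k epochs" means epochs divisible by k.
    return epoch % every_k_epochs == 0
```

Under these assumptions, `every_k_epochs=4` with `save_last_epoch=True` over 10 epochs would save on epochs 0, 4, 8, and 9.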
Files are saved with the naming pattern: {phase}_{name}_epoch{epoch}_batch{batch}.pt
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the predict batch ends.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the test batch ends.