OnlineWriter

class stable_pretraining.callbacks.OnlineWriter(names: str, path: str | Path, during: str | list[str], every_k_epochs: int = -1, save_last_epoch: bool = False, save_sanity_check: bool = False, all_gather: bool = True)

Bases: Callback

Writes specified batch data to disk during the configured phases (train, val, test).

This callback selectively saves batch data (e.g., features, predictions, embeddings) to disk at specified intervals. It is useful for debugging, visualization, and analysis of model behavior during training.

Features:
  • Flexible saving schedule (every k epochs, last epoch, sanity check)

  • Support for distributed training with optional all_gather

  • Automatic directory creation

  • Configurable for different training phases (train, val, test)

Parameters:
  • names – Name(s) of the batch keys to save. Can be a string or a list of strings.

  • path – Directory path where files will be saved.

  • during – Phase(s) during which to save: ‘train’, ‘val’, ‘test’, or a list of these.

  • every_k_epochs – Save every k epochs. -1 means every epoch.

  • save_last_epoch – Whether to save on the last training epoch.

  • save_sanity_check – Whether to save during sanity check phase.

  • all_gather – Whether to gather data across all distributed processes.
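A minimal configuration sketch may help tie the parameters together. The keyword names and defaults come from the signature above; the batch key "embedding", the output directory, and the helper name build_writer are hypothetical choices made for illustration.

```python
from pathlib import Path

# Keyword arguments mirror the documented constructor signature.
# "embedding" is a hypothetical batch key; per the parameter description,
# a list of keys may be passed instead of a single string.
writer_kwargs = dict(
    names="embedding",
    path=Path("outputs") / "online_writer",  # hypothetical output directory
    during=["train", "val"],                 # phases in which to save
    every_k_epochs=5,                        # -1 would mean every epoch
    save_last_epoch=True,
    save_sanity_check=False,
    all_gather=True,
)


def build_writer():
    """Instantiate the callback (requires stable_pretraining to be installed)."""
    # Imported lazily so the configuration above can be inspected without the library.
    from stable_pretraining.callbacks import OnlineWriter

    return OnlineWriter(**writer_kwargs)
```

Assuming the Callback base class is the Lightning one, as the hook names below suggest, the object returned by build_writer() would be passed to the trainer through its callbacks argument.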

Files are saved with naming pattern: {phase}_{name}_epoch{epoch}_batch{batch}.pt
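The naming pattern can be used to locate or index saved files after a run. The sketch below builds and parses file names that follow it. The helper names expected_file and parse_file are hypothetical, and two details are assumptions not stated here: that epoch and batch are written as plain integers without zero padding, and that the phase string contains no underscore.

```python
import re
from pathlib import Path

# Pattern quoted from the documentation.
PATTERN = "{phase}_{name}_epoch{epoch}_batch{batch}.pt"

# Assumes the phase has no underscore and epoch/batch are plain integers.
_FILE_RE = re.compile(
    r"^(?P<phase>[^_]+)_(?P<name>.+)_epoch(?P<epoch>\d+)_batch(?P<batch>\d+)\.pt$"
)


def expected_file(root, phase, name, epoch, batch):
    """Return the path a saved entry would have under the documented pattern."""
    return Path(root) / PATTERN.format(phase=phase, name=name, epoch=epoch, batch=batch)


def parse_file(path):
    """Split a file name into phase, name, epoch, and batch, or return None."""
    match = _FILE_RE.match(Path(path).name)
    if match is None:
        return None
    fields = match.groupdict()
    fields["epoch"] = int(fields["epoch"])
    fields["batch"] = int(fields["batch"])
    return fields


example = expected_file("outputs", "val", "embedding", 3, 12)
print(example.name)
print(parse_file(example))
```

The .pt extension suggests the files are written with torch.save, in which case torch.load would read them back; this is an inference from the extension, not something the documentation confirms.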

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the predict batch ends.

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the test batch ends.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the train batch ends.

Note

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.
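As a small worked example of this note, reading "normalized" as division by accumulate_grad_batches: a training_step loss of 2.0 with accumulate_grad_batches=4 would appear in this hook as 0.5. The function name normalized_loss is hypothetical and only illustrates the arithmetic.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Loss value as seen in outputs["loss"], assuming normalization is a plain division."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be at least 1")
    return step_loss / accumulate_grad_batches


# With no accumulation the value is unchanged; with 4 accumulated batches it is a quarter.
print(normalized_loss(2.0, 1))
print(normalized_loss(2.0, 4))
```

To compare values across runs with different accumulation settings, multiplying the hook's value by accumulate_grad_batches would recover the original scale under the same assumption.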

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the validation batch ends.