Source code for stable_pretraining.callbacks.cleanup

import os
import glob
import shutil
from typing import List, Tuple, Optional, Sequence
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.trainer.trainer import Trainer
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.utilities.rank_zero import rank_zero_only

try:
    from hydra.core.hydra_config import HydraConfig
except ImportError:
    HydraConfig = None


def human_size(nbytes: int) -> str:
    """Convert a file size in bytes to a human-readable string.

    Args:
        nbytes (int): File size in bytes.

    Returns:
        str: Human-readable file size (e.g., '1.2 MB').
    """
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if nbytes < 1024:
            return f"{nbytes:.1f} {unit}"
        nbytes /= 1024
    return f"{nbytes:.1f} PB"
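
# Illustrative sketch (not part of the module): expected behaviour of
# ``human_size`` for a few representative inputs.
#
#   human_size(512)          -> "512.0 B"
#   human_size(1536)         -> "1.5 KB"
#   human_size(5 * 1024**3)  -> "5.0 GB"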

def _resolve_hydra_output_dir() -> str:
    """Resolve the Hydra output directory if available, else fall back to the current working directory.

    Returns:
        str: Path to the Hydra output directory or current working directory.
    """
    if HydraConfig is not None:
        try:
            return HydraConfig.get().runtime.output_dir
        except Exception:
            pass
    return os.getcwd()

class CleanUpCallback(Callback):
    """PyTorch Lightning callback to monitor and clean up SLURM and Hydra files during and after training.

    At each epoch, prints the names and sizes of SLURM and Hydra files.
    At the end of successful training, deletes those files and prints a summary.
    Safe for DDP (only rank 0 prints/deletes).

    Args:
        slurm_patterns (Optional[Sequence[str]]): Glob patterns for SLURM files to
            monitor/delete. Defaults to ["slurm-*.out", "slurm-*.err"].
        search_paths (Optional[Sequence[str]]): Directories to search for SLURM files.
            Defaults to [os.getcwd(), $SLURM_SUBMIT_DIR if set].
        delete_hydra (bool): Whether to delete Hydra artifacts (.hydra directory and
            hydra.log). Defaults to True.
        dry_run (bool): If True, only print what would be deleted; do not actually
            delete. Defaults to False.

    Example:
        ```python
        from lightning.pytorch import Trainer

        cleanup_cb = CleanUpCallback()
        trainer = Trainer(callbacks=[cleanup_cb])
        ```

    Example (custom patterns, dry run):
        ```python
        cleanup_cb = CleanUpCallback(
            slurm_patterns=["slurm-*.out", "myjob-*.log"],
            search_paths=["/scratch/logs", "/tmp"],
            delete_hydra=False,
            dry_run=True,
        )
        ```
    """

    def __init__(
        self,
        slurm_patterns: Optional[Sequence[str]] = None,
        search_paths: Optional[Sequence[str]] = None,
        delete_hydra: bool = True,
        dry_run: bool = False,
    ) -> None:
        self.slurm_patterns: Sequence[str] = slurm_patterns or [
            "slurm-*.out",
            "slurm-*.err",
        ]
        self.search_paths: List[str] = (
            list(search_paths) if search_paths else [os.getcwd()]
        )
        slurm_submit_dir = os.environ.get("SLURM_SUBMIT_DIR")
        if slurm_submit_dir and slurm_submit_dir not in self.search_paths:
            self.search_paths.append(slurm_submit_dir)
        self.delete_hydra: bool = delete_hydra
        self.dry_run: bool = dry_run
        self._exception: bool = False
        self._files_to_delete: List[Tuple[str, str, int]] = []

    def _find_files(self) -> List[Tuple[str, str, int]]:
        """Find SLURM and Hydra files to monitor/delete.

        Returns:
            List[Tuple[str, str, int]]: List of (type, path, size) tuples.
        """
        files: List[Tuple[str, str, int]] = []
        # SLURM files
        for path in self.search_paths:
            for pattern in self.slurm_patterns:
                for f in glob.glob(os.path.join(path, pattern)):
                    if os.path.isfile(f):
                        files.append(("SLURM", f, os.path.getsize(f)))
        # Hydra files
        hydra_dir = _resolve_hydra_output_dir()
        hydra_log = os.path.join(hydra_dir, "hydra.log")
        if os.path.isfile(hydra_log):
            files.append(("Hydra", hydra_log, os.path.getsize(hydra_log)))
        hydra_dot_dir = os.path.join(hydra_dir, ".hydra")
        if os.path.isdir(hydra_dot_dir):
            total_size = 0
            for root, _, fs in os.walk(hydra_dot_dir):
                for f in fs:
                    fp = os.path.join(root, f)
                    try:
                        total_size += os.path.getsize(fp)
                    except Exception:
                        pass
            files.append(("Hydra", hydra_dot_dir, total_size))
        return files

    @rank_zero_only
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Print SLURM and Hydra file info at the end of each epoch.

        Args:
            trainer (Trainer): The PyTorch Lightning Trainer.
            pl_module (LightningModule): The LightningModule being trained.
        """
        files = self._find_files()
        self._files_to_delete = files
        print(f"\n[CleanUpCallback] Epoch {trainer.current_epoch}:")
        if not files:
            print("  No SLURM/Hydra files found.")
        for typ, f, sz in files:
            print(f"  [{typ}] {f} ({human_size(sz)})")

    @rank_zero_only
    def on_exception(
        self, trainer: Trainer, pl_module: LightningModule, exception: BaseException
    ) -> None:
        """Mark that an exception occurred, so files will not be deleted at the end.

        Args:
            trainer (Trainer): The PyTorch Lightning Trainer.
            pl_module (LightningModule): The LightningModule being trained.
            exception (BaseException): The exception that was raised.
        """
        self._exception = True

    @rank_zero_only
    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Delete SLURM and Hydra files at the end of training if no exception occurred.

        Args:
            trainer (Trainer): The PyTorch Lightning Trainer.
            pl_module (LightningModule): The LightningModule being trained.
        """
        if self._exception:
            print("[CleanUpCallback] Training failed, skipping file deletion.")
            return
        print("\n[CleanUpCallback] Cleaning up files after successful training:")
        for typ, f, sz in self._files_to_delete:
            if typ == "Hydra" and not self.delete_hydra:
                print(f"  Skipping Hydra artifact: {f}")
                continue
            if self.dry_run:
                print(f"  Dry run: would delete {f} ({human_size(sz)})")
                continue
            try:
                if os.path.isdir(f):
                    shutil.rmtree(f)
                    print(f"  Deleted directory: {f}")
                else:
                    os.remove(f)
                    print(f"  Deleted file: {f}")
            except Exception as e:
                print(f"  Failed to delete {f}: {e}")
        print("[CleanUpCallback] Cleanup complete.")
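
# Illustrative usage sketch (assumptions: a LightningModule ``MyModule`` and a
# DataLoader ``train_loader`` exist; neither is defined in this module):
#
#   from lightning.pytorch import Trainer
#
#   cleanup_cb = CleanUpCallback(dry_run=True)  # preview deletions first
#   trainer = Trainer(max_epochs=1, callbacks=[cleanup_cb])
#   trainer.fit(MyModule(), train_loader)
#
# With ``dry_run=True`` the callback only reports which SLURM/Hydra files it
# would remove; set ``dry_run=False`` to actually delete them after a
# successful run.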