from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import pandas as pd
import yaml
from loguru import logger
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import os
import concurrent.futures
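# This module has two parts:
#   1. Best-compression helpers (_optimize_dataframe, _save_variant, _get_trials,
#      save_best_compressed, load_best_compressed) that benchmark several
#      format/compression combinations and keep only the smallest file.
#   2. CSVLogAutoSummarizer, which aggregates PyTorch Lightning CSVLogger metrics
#      from Hydra multirun sweeps into a single summary DataFrame.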
# ========== Internal Functions ==========
def _optimize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""Optimize DataFrame memory usage by downcasting numeric types and converting to categories.
This function reduces DataFrame memory footprint through several strategies:
- Downcast float columns to the smallest float dtype that can represent the values
- Downcast integer columns to the smallest integer dtype that can represent the values
- Convert Path objects to strings for serialization compatibility
- Convert low-cardinality object columns (< 50% unique values) to categorical dtype
Args:
df: Input DataFrame to optimize.
Returns:
Optimized copy of the DataFrame with reduced memory usage.
Notes:
- Creates a copy of the input DataFrame; original is not modified
- Path object conversion is logged as a warning
- Categorical conversion threshold is 50% unique values ratio
- Optimization progress is logged at INFO level
- Handles empty DataFrames gracefully
Example:
>>> df = pd.DataFrame(
... {
        ...         "float_col": [1.0, 2.0, 3.0, 4.0, 5.0],
        ...         "int_col": [1, 2, 3, 4, 5],
... "category_col": ["a", "b", "a", "b", "a"],
... }
... )
>>> df_optimized = _optimize_dataframe(df)
>>> df_optimized.memory_usage(deep=True).sum() < df.memory_usage(
... deep=True
... ).sum()
True
"""
logger.info("Optimizing DataFrame dtypes...")
df_opt = df.copy()
# Handle empty DataFrames early
if len(df_opt) == 0:
logger.info("DataFrame is empty, skipping optimization.")
return df_opt
for col in df_opt.select_dtypes(include=["float"]):
df_opt[col] = pd.to_numeric(df_opt[col], downcast="float")
for col in df_opt.select_dtypes(include=["integer"]):
df_opt[col] = pd.to_numeric(df_opt[col], downcast="integer")
for col in df_opt.columns:
# Convert PosixPath or WindowsPath to str
if df_opt[col].apply(lambda x: isinstance(x, Path)).any():
logger.warning(f"Column '{col}' contains Path objects, converting to str.")
df_opt[col] = df_opt[col].astype(str)
# Now handle categoricals
if df_opt[col].dtype == "object":
num_unique = df_opt[col].nunique()
num_total = len(df_opt[col])
            # Convert low-cardinality object columns (< 50% unique values) to categorical
if num_total > 0 and num_unique / num_total < 0.5:
df_opt[col] = df_opt[col].astype("category")
logger.info("Optimization complete.")
return df_opt
def _save_variant(df, base_path, fmt, comp):
"""Save DataFrame in a specific format with compression and return file size.
Creates a file with naming pattern: {base_path}__{format}__{compression}.{ext}
Args:
df: DataFrame to save.
base_path: Base path/name for the output file (without extension).
fmt: File format. Supported values: 'parquet', 'csv', 'feather', 'pickle'.
comp: Compression algorithm. Valid options depend on format:
- parquet: 'brotli', 'gzip', 'snappy', 'zstd', 'none'
- csv: 'gzip', 'bz2', 'xz', 'zstd', 'zip'
- feather: 'zstd', 'lz4', 'uncompressed'
- pickle: 'infer', 'gzip', 'bz2', 'xz', 'zstd'
Returns:
Tuple of (filename, size_in_bytes) if successful, (None, None) if failed.
Notes:
- CSV files are saved without index
- File extensions are automatically appended based on format
- Failures are logged but don't raise exceptions (returns None, None)
- Unknown formats trigger a warning and return None, None
Examples:
>>> df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> filename, size = _save_variant(df, "data/output", "parquet", "zstd")
>>> filename
'data/output__parquet__zstd.parquet'
>>> size > 0
True
>>> # CSV with gzip compression
>>> filename, size = _save_variant(df, "data/output", "csv", "gzip")
>>> filename
'data/output__csv__gzip.csv.gz'
"""
filename = f"{base_path}__{fmt}__{comp}"
try:
if fmt == "parquet":
filename += ".parquet"
df.to_parquet(filename, compression=comp)
elif fmt == "csv":
ext = {
"gzip": "gz",
"bz2": "bz2",
"xz": "xz",
"zstd": "zst",
"zip": "zip",
}.get(comp, comp)
filename += f".csv.{ext}"
df.to_csv(filename, compression=comp, index=False)
elif fmt == "feather":
filename += ".feather"
df.to_feather(filename, compression=comp)
elif fmt == "pickle":
filename += ".pkl"
df.to_pickle(filename, compression=comp)
else:
logger.warning(f"Unknown format: {fmt}")
return None, None
size = os.path.getsize(filename)
logger.debug(f"Saved {filename} ({size / 1024:.2f} KB)")
return filename, size
except Exception as e:
logger.error(f"Failed to save {filename}: {e}")
return None, None
def _get_trials():
    """Return the list of (format, compression) combinations to benchmark."""
    # Add more combinations as needed
return [
# Parquet
("parquet", "brotli"),
("parquet", "zstd"),
("parquet", "gzip"),
("parquet", "snappy"),
# Feather
("feather", "zstd"),
("feather", "lz4"),
("feather", "uncompressed"),
# CSV
("csv", "gzip"),
("csv", "bz2"),
("csv", "xz"),
("csv", "zstd"),
("csv", "zip"),
# Pickle
("pickle", "infer"),
# Add more if you want (e.g., external compression with subprocess)
]
def _parse_filename(filename):
"""Extract format and compression from filename following the naming convention.
Parses filenames created by _save_variant with pattern:
{base_path}__{format}__{compression}.{extension}
Args:
filename: Filename or path to parse. Can be absolute path or basename.
Returns:
Tuple of (format, compression) extracted from the filename.
Raises:
ValueError: If filename doesn't contain at least 3 parts separated by '__'.
Examples:
>>> _parse_filename("data__parquet__zstd.parquet")
('parquet', 'zstd')
>>> _parse_filename("/path/to/output__csv__gzip.csv.gz")
('csv', 'gzip')
>>> _parse_filename("output__feather__lz4.feather")
('feather', 'lz4')
>>> _parse_filename("invalid_filename.csv")
        Traceback (most recent call last):
            ...
        ValueError: Filename invalid_filename.csv does not match expected pattern.
Notes:
- Only the basename is used; directory paths are stripped
- The compression is extracted before any file extensions
- This is the inverse operation of the naming in _save_variant
"""
basename = os.path.basename(filename)
parts = basename.split("__")
if len(parts) < 3:
raise ValueError(f"Filename {filename} does not match expected pattern.")
fmt = parts[1]
comp = parts[2].split(".")[0]
return fmt, comp
# ========== User-Facing Functions ==========
def _get_extension(fmt: str, comp: str) -> str:
"""Get the appropriate file extension for format/compression combination."""
if fmt == "parquet":
return ".parquet"
elif fmt == "feather":
return ".feather"
elif fmt == "pickle":
# Pickle compression is in the file itself
return ".pkl"
elif fmt == "csv":
# CSV needs compression-specific extensions
comp_ext = {
"gzip": ".gz",
"bz2": ".bz2",
"xz": ".xz",
"zstd": ".zst",
"zip": ".zip",
}.get(comp, "")
return f".csv{comp_ext}"
return ""
def _save_variant(
df: pd.DataFrame, base_path: str, fmt: str, comp: str
) -> tuple[Optional[str], Optional[int]]:
"""Save DataFrame in a specific format with compression and return file size.
Creates a file with clean naming: {base_path}.{ext}
Uses a temporary suffix during trials to avoid conflicts.
Args:
df: DataFrame to save.
base_path: Base path/name for the output file (without extension).
fmt: File format. Supported values: 'parquet', 'csv', 'feather', 'pickle'.
comp: Compression algorithm. Valid options depend on format.
Returns:
Tuple of (filename, size_in_bytes) if successful, (None, None) if failed.
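    Examples:
        The trial filename below follows the _trial_ naming used while benchmarking;
        save_best_compressed later renames the winning file to a clean name.
        >>> df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        >>> filename, size = _save_variant(df, "data", "parquet", "zstd")
        >>> filename
        'data_trial_parquet_zstd.parquet'
        >>> size > 0
        True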
"""
ext = _get_extension(fmt, comp)
# Use temporary suffix to distinguish trial files
filename = f"{base_path}_trial_{fmt}_{comp}{ext}"
try:
if fmt == "parquet":
df.to_parquet(filename, compression=comp)
elif fmt == "csv":
df.to_csv(filename, compression=comp, index=False)
elif fmt == "feather":
df.to_feather(filename, compression=comp)
elif fmt == "pickle":
df.to_pickle(filename, compression=comp)
else:
logger.warning(f"Unknown format: {fmt}")
return None, None
size = os.path.getsize(filename)
logger.debug(f"Saved {filename} ({size / 1024:.2f} KB)")
return filename, size
except Exception as e:
logger.error(f"Failed to save {filename}: {e}")
return None, None
def _strip_extensions(base_path: str) -> str:
"""Remove all file extensions from base path to ensure clean naming.
Args:
base_path: Path that may contain extensions.
Returns:
Path with all extensions removed.
Examples:
>>> _strip_extensions("data.parquet")
'data'
>>> _strip_extensions("results.csv.gz")
'results'
>>> _strip_extensions("/path/to/file.feather")
'/path/to/file'
>>> _strip_extensions("data")
'data'
"""
path = Path(base_path)
parent = path.parent
name = path.name
# Keep removing suffixes until there are none left
while True:
name_path = Path(name)
if not name_path.suffix:
break
name = name_path.stem
# Reconstruct with original directory
if str(parent) != ".":
return str(parent / name)
return name
def save_best_compressed(df: pd.DataFrame, base_path: str = "data") -> str:
"""Optimize DataFrame and save using the most space-efficient format/compression combination.
This function:
1. Optimizes DataFrame dtypes to reduce memory usage
2. Tries multiple format/compression combinations in parallel
3. Keeps only the smallest resulting file with a clean filename
4. Automatically cleans up all trial files
The following format/compression combinations are tested:
- Parquet: brotli, zstd, gzip, snappy
- Feather: zstd, lz4, uncompressed
- CSV: gzip, bz2, xz, zstd, zip
- Pickle: infer
Args:
df: DataFrame to save. Will be optimized before saving.
base_path: Base path/name for output file. Any extensions will be
automatically removed and replaced with the optimal format.
Default: 'data'
Returns:
Filename of the best-compressed file (smallest size) with clean extension.
Raises:
RuntimeError: If all save attempts fail (no files were created).
Notes:
- Uses parallel processing (ThreadPoolExecutor) for I/O-bound operations
- All trial files except the smallest are automatically deleted
- Original DataFrame is not modified (optimization works on a copy)
- Final file uses standard extensions (e.g., .parquet, .csv.gz, .feather)
- Any extensions in base_path are automatically stripped
- If multiple formats produce the same size, the first one sorted is kept
Performance:
- Time: Depends on number of trials and DataFrame size
- Memory: Peak usage ~2x DataFrame size (original + optimized copy)
- Disk: Temporarily creates all trial files before cleanup
Examples:
        >>> import numpy as np
        >>> df = pd.DataFrame(
        ...     {
        ...         "timestamp": pd.date_range("2024-01-01", periods=1000),
        ...         "value": np.random.randn(1000),
        ...         "category": np.random.choice(["A", "B", "C"], 1000),
        ...     }
        ... )
        >>> # Extension automatically added (the winning format depends on the data)
        >>> best_file = save_best_compressed(df, "results/experiment_01")
        >>> best_file
        'results/experiment_01.parquet'
        >>> # Extensions in the input are stripped and replaced by the optimal format
        >>> best_file = save_best_compressed(df, "data.csv")
        >>> best_file
        'data.parquet'
        >>> # Works with multiple extensions; all of them are stripped
        >>> best_file = save_best_compressed(df, "output.csv.gz")
        >>> best_file
        'output.feather'
See Also:
_optimize_dataframe: For details on DataFrame optimization
_get_trials: For the complete list of format/compression combinations
load_best_compressed: Smart loader that auto-detects format
"""
# Strip any extensions from base_path
base_path_clean = _strip_extensions(base_path)
if base_path_clean != base_path:
logger.info(f"Stripped extensions: '{base_path}' → '{base_path_clean}'")
logger.info("Starting best-compression save process...")
df_opt = _optimize_dataframe(df)
options = _get_trials()
results = []
with concurrent.futures.ThreadPoolExecutor() as executor:
# Store futures with their corresponding format/compression
future_to_options = {
executor.submit(_save_variant, df_opt, base_path_clean, fmt, comp): (
fmt,
comp,
)
for fmt, comp in options
}
        for future in concurrent.futures.as_completed(future_to_options):
            # Recover the format/compression pair for this completed future
            fmt, comp = future_to_options[future]
            filename, size = future.result()
            if filename is not None and size is not None:
                results.append((size, filename, fmt, comp))
if not results:
logger.error("No files were saved!")
raise RuntimeError("No files were saved!")
results.sort()
best_size, best_trial_file, best_fmt, best_comp = results[0]
# Rename best file to clean name
best_ext = _get_extension(best_fmt, best_comp)
final_filename = f"{base_path_clean}{best_ext}"
try:
os.rename(best_trial_file, final_filename)
logger.success(
f"Best compression: {best_fmt}/{best_comp} "
f"({best_size / 1024:.2f} KB) → {final_filename}"
)
except Exception as e:
logger.error(f"Failed to rename {best_trial_file} to {final_filename}: {e}")
final_filename = best_trial_file
# Cleanup other trial files
for _, fname, _, _ in results[1:]:
try:
os.remove(fname)
logger.debug(f"Removed {fname}")
except Exception as e:
logger.warning(f"Could not remove {fname}: {e}")
return final_filename
def load_best_compressed(filename: str) -> pd.DataFrame:
"""Load a DataFrame from a compressed file with automatic format detection.
Automatically detects the file format from the extension and tries multiple
read methods if needed. This is a smart loader that works with any compressed
DataFrame file, not just those created by save_best_compressed.
Args:
filename: Path to the compressed file. Standard extensions are recognized:
- .parquet → Parquet format
- .feather → Feather format
- .pkl, .pickle → Pickle format
- .csv, .csv.gz, .csv.bz2, etc. → CSV format
Returns:
Loaded DataFrame with original dtypes preserved.
Raises:
FileNotFoundError: If the specified file doesn't exist.
ValueError: If no read method successfully loads the file.
Supported Formats:
- Parquet (compression auto-detected)
- Feather (compression auto-detected)
- CSV (compression auto-detected from extension)
- Pickle (compression auto-detected)
Notes:
- Extension-based detection is tried first for speed
- If extension is ambiguous/missing, tries all formats sequentially
- Compression is handled automatically by pandas
- Works with any properly formatted DataFrame file, not just from this module
Examples:
>>> # Load with clear extension
>>> df = load_best_compressed("data.parquet")
>>> # Load compressed CSV
>>> df = load_best_compressed("results.csv.gz")
>>> # Load with full path
>>> df = load_best_compressed("/path/to/experiment.feather")
>>> # Even works with ambiguous extensions (tries all methods)
>>> df = load_best_compressed("mystery_file.dat")
See Also:
save_best_compressed: Saves DataFrame with optimal compression
"""
if not os.path.exists(filename):
raise FileNotFoundError(f"File not found: {filename}")
logger.info(f"Loading {filename}...")
# Extension-based detection
lower_filename = filename.lower()
# Define read strategies in order of likelihood
strategies = []
if lower_filename.endswith(".parquet"):
strategies.append(("parquet", pd.read_parquet))
elif lower_filename.endswith(".feather"):
strategies.append(("feather", pd.read_feather))
elif lower_filename.endswith((".pkl", ".pickle")):
strategies.append(("pickle", pd.read_pickle))
elif ".csv" in lower_filename:
strategies.append(("csv", pd.read_csv))
# Fallback: try all methods if extension doesn't match
if not strategies:
logger.warning("Unknown extension, trying all formats...")
strategies = [
("parquet", pd.read_parquet),
("feather", pd.read_feather),
("pickle", pd.read_pickle),
("csv", pd.read_csv),
]
else:
# Add other formats as fallbacks
all_methods = [
("parquet", pd.read_parquet),
("feather", pd.read_feather),
("pickle", pd.read_pickle),
("csv", pd.read_csv),
]
# Add methods we haven't tried yet
for method in all_methods:
if method not in strategies:
strategies.append(method)
# Try each strategy
last_error = None
for fmt, read_func in strategies:
try:
logger.debug(f"Trying to load as {fmt}...")
df = read_func(filename)
logger.success(f"Successfully loaded as {fmt} format")
return df
except Exception as e:
logger.debug(f"Failed to load as {fmt}: {e}")
last_error = e
continue
# If we get here, nothing worked
logger.error(f"Could not load {filename} with any known format")
raise ValueError(
f"Failed to load {filename}. Tried: {', '.join(s[0] for s in strategies)}. "
f"Last error: {last_error}"
)
class CSVLogAutoSummarizer:
"""Automatically discovers and summarizes PyTorch Lightning CSVLogger metrics from Hydra multirun sweeps.
Handles arbitrary directory layouts, sparse metrics, multiple versions (preemption/resume), and aggregates
config/hparams metadata into a single DataFrame.
Features:
- Recursively finds metrics CSVs using common patterns.
- Infers run root by searching for .hydra/config.yaml (falls back to metrics parent).
    - Merges multiple metrics files found under the same run root (e.g. after
      preemption/resume) into a single frame.
    - Tolerant CSV parsing (drops unnamed columns, coerces numeric columns,
      forward-fills sparse rows).
- Sparse-metrics aware: last values are last non-NaN; best values ignore NaNs.
    - Loads hparams.yaml metadata and attaches it as config/<key> columns.
    Args:
        agg: Optional callable applied to each merged per-run DataFrame, e.g. to
            reduce it to a single summary row. May return a DataFrame or a Series.
        monitor_keys: Metrics to summarize (if None, auto-infer).
        include_globs: Only include files matching these globs (relative to base_dir).
        exclude_globs: Exclude files matching these globs (e.g., '**/checkpoints/**').
        max_workers: Maximum number of worker processes used when reading runs in parallel.
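    Example:
        A minimal sketch; the directory below is hypothetical and should point at a
        Hydra multirun output containing Lightning CSVLogger metrics:
        >>> summarizer = CSVLogAutoSummarizer(exclude_globs=["**/checkpoints/**"])
        >>> summary_df = summarizer.collect("multirun/2024-01-01/12-00-00")  # doctest: +SKIP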
"""
METRICS_PATTERNS = ["**/metrics.csv", "**/csv/metrics.csv", "**/*metrics*.csv"]
def __init__(
self,
        agg: Optional[Callable] = None,
monitor_keys: Optional[List[str]] = None,
include_globs: Optional[List[str]] = None,
exclude_globs: Optional[List[str]] = None,
max_workers: Optional[int] = 10,
):
self.agg = agg
self.monitor_keys = monitor_keys
self.include_globs = include_globs or []
self.exclude_globs = exclude_globs or []
self.max_workers = max_workers
def collect(self, base_dir) -> pd.DataFrame:
"""Discover, summarize, and aggregate all runs into a DataFrame.
Args:
base_dir: Root directory to search for metrics files.
        Returns:
            pd.DataFrame: Concatenation of all per-run frames. Columns include the raw
            metric columns, config/<key> values flattened from each run's hparams.yaml,
            and 'root' (the run root directory). If agg is provided, each run is first
            reduced to the row(s) that callable returns.
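        Example:
            Illustrative only; the directory names are hypothetical:
            >>> summarizer = CSVLogAutoSummarizer()
            >>> df = summarizer.collect(["multirun/seed_0", "multirun/seed_1"])  # doctest: +SKIP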
"""
        if not isinstance(base_dir, (list, tuple)):
            base_dir = [base_dir]
dfs = []
for b in base_dir:
dfs.append(self._single_collect(b))
return pd.concat(dfs, ignore_index=True)
def _single_collect(self, base_dir) -> pd.DataFrame:
"""Discover, summarize, and aggregate all runs into a DataFrame.
Args:
base_dir: Root directory to search for metrics files.
Returns:
pd.DataFrame: One row per selected metrics source (per file or per run root),
with flattened columns such as:
- last.val_accuracy
- best.val_loss, best.val_loss.step, best.val_loss.epoch
- config.optimizer.lr, override.0 (or override as joined string), hparams.seed
Also includes 'metrics_path' and 'run_root'.
"""
self.base_dir = Path(base_dir)
metrics_files = self._find_metrics_files()
run_root_to_files: Dict[Path, List[Path]] = {}
for f in tqdm(metrics_files, desc="compiling root paths"):
root = self._find_run_root(f)
run_root_to_files.setdefault(root, []).append(f)
items = list(run_root_to_files.items())
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
summaries = list(
tqdm(
executor.map(self._merge_metrics_files, items),
desc="Loading",
total=len(items),
)
)
summaries = [s for s in summaries if s is not None]
if not summaries:
return pd.DataFrame()
return pd.concat(summaries, ignore_index=True)
# ----------------------------
# Discovery and selection
# ----------------------------
def _find_metrics_files(self) -> List[Path]:
"""Recursively find all metrics CSV files matching known patterns, applying include/exclude globs."""
files: List[Path] = []
for pat in self.METRICS_PATTERNS:
files.extend(self.base_dir.rglob(pat))
        # Unique files
        files = list({f.resolve() for f in files if f.is_file()})
        # Resolve base_dir as well so relative_to() works even for symlinked paths
        base_resolved = self.base_dir.resolve()
        rel_files = [f.relative_to(base_resolved) for f in files]
if self.include_globs:
files = [
f
for f, rel in zip(files, rel_files)
if any(rel.match(g) for g in self.include_globs)
]
if self.exclude_globs:
files = [
f
for f, rel in zip(files, rel_files)
if not any(rel.match(g) for g in self.exclude_globs)
]
logger.info(f"Discovered {len(files)} metrics files.")
return files
def _find_run_root(self, metrics_file: Path) -> Path:
"""Walk up to nearest ancestor with .hydra/config.yaml; else use immediate parent."""
        for parent in metrics_file.parents:
hydra_dir = parent / ".hydra"
if (hydra_dir / "config.yaml").exists():
return parent
return metrics_file.parent
def _merge_metrics_files(self, args) -> Optional[pd.DataFrame]:
"""Merge multiple metrics CSVs for a run root, deduplicate by step/epoch.
Keep the last occurrence for any duplicated step/epoch pair.
"""
root, files = args
dfs = [self._read_data(f) for f in files]
dfs = [df for df in dfs if df is not None and not df.empty]
if not dfs:
return None
df = pd.concat(dfs, ignore_index=True)
df["root"] = root
if self.agg is not None:
df = self.agg(df)
if isinstance(df, pd.Series):
df = df.to_frame().T
elif not isinstance(df, pd.DataFrame):
raise RuntimeError("Can only be series or dataframe")
return df
# ----------------------------
# CSV reading and summarization
# ----------------------------
    def _read_data(self, path: Path) -> Optional[pd.DataFrame]:
        """Read a metrics CSV, coerce numeric columns, and forward-fill sparse rows.
        Attaches flattened hparams.yaml metadata as config/<key> columns.
        Returns None if the file cannot be read or is empty.
        """
        try:
            df = pd.read_csv(path)
        except Exception as e:
            logger.warning(f"Failed to read {path}: {e}")
            return None
        if df.empty:
            return None
        # Drop unnamed columns, strip column names
        df = df.loc[:, ~df.columns.str.contains("^Unnamed")]
        df.columns = df.columns.str.strip()
        # Convert columns to numeric where possible; leave non-numeric columns as-is
        for col in df.columns:
            try:
                df[col] = pd.to_numeric(df[col])
            except (ValueError, TypeError):
                pass
        df.ffill(inplace=True)
        # Diagnostics if all values are NaN (excluding axis/time cols)
        metric_cols = [c for c in df.columns if c not in ("step", "epoch", "time")]
        if metric_cols and df[metric_cols].isna().all().all():
            logger.warning(f"All metric values are NaN in {path}")
            logger.debug(f"Dtypes:\n{df.dtypes}\nHead:\n{df.head()}")
        # Attach hparams metadata, searching upward from the metrics file's directory
        hparams = self._find_hparams(path.parent)
        if isinstance(hparams, dict):
            for k, v in hparams.items():
                # Store lists/dicts as strings so they fit in a single cell
                df[f"config/{k}"] = str(v) if isinstance(v, (list, dict)) else v
        return df
# ----------------------------
# Metadata loading and flattening
# ----------------------------
def _load_yaml(self, path: Path) -> Any:
"""Load a YAML file. Returns the parsed object, which may be dict, list, or scalar.
Returns {} if file is missing or unreadable.
"""
if not path.exists():
return {}
try:
with open(path, "r") as f:
obj = yaml.safe_load(f)
return {} if obj is None else obj
except Exception as e:
logger.warning(f"Failed to load YAML {path}: {e}")
return {}
def _find_hparams(self, start_dir: Path) -> Any:
"""Search upward from start_dir for hparams.yaml (first found). Return the parsed object (dict/list/scalar) or {}."""
for parent in [start_dir] + list(start_dir.parents):
hp = parent / "hparams.yaml"
if hp.exists():
return self._load_yaml(hp)
return {}