stable_pretraining.utils package#
Submodules#
stable_pretraining.utils.batch_utils module#
Utility functions for handling batch and outputs dictionaries in callbacks.
- stable_pretraining.utils.batch_utils.detach_tensors(obj: Any) Any[source]#
Recursively traverse an object and return an equivalent structure with all torch tensors detached.
Preserves structure, types, and shared references.
Handles cycles and arbitrary Python objects (including __dict__ and __slots__).
Does not mutate the input; only rebuilds containers if needed.
torch.nn.Parameter is replaced with a detached Tensor (not Parameter).
Optionally supports attrs classes if ‘attr’ is installed.
- Parameters:
obj – The input object (can be arbitrarily nested).
- Returns:
A new object with all torch tensors detached, or the original object if no tensors found.
- Performance notes:
Uses memoization to avoid redundant work and preserve shared/cyclic structure.
Avoids unnecessary copies: unchanged subtrees are returned as-is (same id).
Shallow-copies objects with __dict__ or __slots__ (does not call __init__).
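A minimal usage sketch (the batch contents are illustrative; the import path follows this module):
```python
import torch
from stable_pretraining.utils.batch_utils import detach_tensors

batch = {
    "image": torch.randn(4, 3, 32, 32, requires_grad=True),
    "meta": {"ids": [1, 2, 3, 4]},  # subtree without tensors
}
safe = detach_tensors(batch)
assert not safe["image"].requires_grad  # tensor is detached from the autograd graph
assert safe["meta"] is batch["meta"]    # unchanged subtree returned as-is (see notes above)
```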
- stable_pretraining.utils.batch_utils.get_data_from_batch_or_outputs(key: Iterable[str] | str, batch: Dict[str, Any], outputs: Dict[str, Any] | None = None, caller_name: str = 'Callback') Any | None[source]#
Get data from either outputs or batch dictionary.
In PyTorch Lightning, the outputs parameter in callbacks contains the return value from training_step/validation_step, while batch contains the original input. Since forward methods may modify batch in-place but Lightning creates a copy for outputs, we need to check both.
- Parameters:
key – The key(s) to look for in the dictionaries
batch – The original batch dictionary
outputs – The outputs dictionary from training/validation step
caller_name – Name of the calling function/class for logging
- Returns:
The data associated with the key, or None if not found
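A hedged sketch of how a callback might use this helper (the "embedding" key and caller name are illustrative):
```python
import torch
from stable_pretraining.utils.batch_utils import get_data_from_batch_or_outputs

batch = {"label": torch.tensor([0, 1])}
outputs = {"embedding": torch.randn(2, 128)}  # e.g. returned by training_step

# Checks both dictionaries; returns None if the key is in neither (caller_name is used for logging).
emb = get_data_from_batch_or_outputs("embedding", batch, outputs, caller_name="MyCallback")
label = get_data_from_batch_or_outputs("label", batch, outputs, caller_name="MyCallback")
```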
stable_pretraining.utils.config module#
Configuration utilities and model manipulation helpers.
- stable_pretraining.utils.config.adapt_resnet_for_lowres(model)[source]#
Adapt a ResNet model for low resolution images.
Modifies the first convolution layer to use 3x3 kernels with stride 1 and removes the max pooling layer.
- Parameters:
model – ResNet model to adapt
- Returns:
Modified model
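A minimal sketch using a torchvision ResNet (torchvision itself is an assumption; any ResNet with the standard conv1/maxpool stem follows the same pattern):
```python
import torchvision
from stable_pretraining.utils.config import adapt_resnet_for_lowres

model = torchvision.models.resnet18(num_classes=10)
model = adapt_resnet_for_lowres(model)  # 3x3 stride-1 first conv, max pooling removed (e.g. for 32x32 CIFAR inputs)
```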
- stable_pretraining.utils.config.execute_from_config(manager, cfg)[source]#
Execute a function with support for submitit job submission.
If submitit configuration is present, submits the job to a cluster. Otherwise executes locally.
- Parameters:
manager – Function or callable to execute
cfg – Configuration dictionary
- Returns:
Result of the executed function
- stable_pretraining.utils.config.find_module(model: Module, module: Module)[source]#
Find all instances of a module type in a model.
- Parameters:
model – Model to search in
module – Module class to search for
- Returns:
Tuple of (names, modules) where names are the module paths and modules are the actual module instances
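A short sketch locating all BatchNorm2d layers in a torchvision ResNet (the model choice is illustrative):
```python
import torch.nn as nn
import torchvision
from stable_pretraining.utils.config import find_module

model = torchvision.models.resnet18()
names, modules = find_module(model, nn.BatchNorm2d)
print(len(names), names[:3])  # dotted module paths paired with the matching module instances
```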
- stable_pretraining.utils.config.is_dist() bool[source]#
Returns True if torch.distributed is available and initialized.
- stable_pretraining.utils.config.load_hparams_from_ckpt(path: str) Any[source]#
Loads a checkpoint safely in both distributed and non-distributed settings.
If torch.distributed is initialized, only src_rank loads from disk, then broadcasts to all other ranks.
If not, loads directly from disk.
- Parameters:
path – Path to checkpoint file.
- Returns:
The loaded hparams.
- stable_pretraining.utils.config.replace_module(model, replacement_mapping)[source]#
Replace modules in a model based on a mapping function.
- Parameters:
model – PyTorch model to modify
replacement_mapping – Function that takes (name, module) and returns the replacement module
- Returns:
Modified model
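A sketch of a replacement mapping; the mapping is called with every (name, module) pair and must return either the module to keep or its replacement (the ReLU-to-GELU swap is illustrative):
```python
import torch.nn as nn
import torchvision
from stable_pretraining.utils.config import replace_module

def relu_to_gelu(name, module):
    # Replace every ReLU; return the original module for everything else.
    return nn.GELU() if isinstance(module, nn.ReLU) else module

model = replace_module(torchvision.models.resnet18(), relu_to_gelu)
```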
stable_pretraining.utils.data_generation module#
Sample generation utilities for SSL experiments.
- stable_pretraining.utils.data_generation.generate_dae_samples(x, n, eps, num_workers=10)[source]#
Generate samples for Denoising Autoencoder (DAE) training.
- Parameters:
x – List of input images
n – Number of noisy versions per image
eps – Noise level (variance)
num_workers – Number of parallel workers
- Returns:
Tuple of (noisy_images, similarity_matrix)
- stable_pretraining.utils.data_generation.generate_dm_samples(x, n, betas, i, num_workers=10)[source]#
Generate samples for Diffusion Model training.
- Parameters:
x – List of input images
n – Number of noisy versions per timestep
betas – Noise schedule beta values
i – Timestep indices to use
num_workers – Number of parallel workers
- Returns:
Tuple of (noisy_images, similarity_matrix)
- stable_pretraining.utils.data_generation.generate_ssl_samples(x, n, num_workers=10)[source]#
Generate augmented samples for self-supervised learning.
Creates n augmented versions of each image.
- Parameters:
x – List of input images
n – Number of augmented versions per image
num_workers – Number of parallel workers
- Returns:
Tuple of (augmented_images, similarity_matrix)
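A hedged sketch, assuming x is a small list of image tensors (the exact image type expected and the augmentation pipeline used are not specified here):
```python
import torch
from stable_pretraining.utils.data_generation import generate_ssl_samples

x = [torch.rand(3, 32, 32) for _ in range(4)]             # assumption: list of CHW image tensors
views, sim = generate_ssl_samples(x, n=2, num_workers=2)  # 2 augmented views per image + similarity matrix
```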
- stable_pretraining.utils.data_generation.generate_sup_samples(x, y, n, num_workers=10)[source]#
Generate samples for supervised learning with class structure.
Only includes classes with at least n samples.
- Parameters:
x – List of input images
y – Class labels
n – Minimum samples per class
num_workers – Number of parallel workers
- Returns:
Tuple of (processed_images, class_similarity_matrix)
stable_pretraining.utils.distance_metrics module#
Distance metric functions for computing pairwise distances between tensors.
- stable_pretraining.utils.distance_metrics.compute_pairwise_distances(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean') Tensor[source]#
Compute pairwise distances between two sets of vectors.
- Parameters:
x – Tensor of shape (n, d) containing n vectors of dimension d
y – Tensor of shape (m, d) containing m vectors of dimension d
metric – Distance metric to use. Options: “euclidean” (L2 distance), “squared_euclidean” (squared L2 distance), “cosine” (cosine distance, 1 - cosine_similarity), “manhattan” (L1 distance).
- Returns:
Distance matrix of shape (n, m) where element (i, j) is the distance between x[i] and y[j]
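A small usage sketch; shapes follow the contract above:
```python
import torch
from stable_pretraining.utils.distance_metrics import compute_pairwise_distances

x = torch.randn(5, 16)
y = torch.randn(8, 16)
d = compute_pairwise_distances(x, y, metric="cosine")
print(d.shape)  # torch.Size([5, 8]); d[i, j] = 1 - cosine_similarity(x[i], y[j])
```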
- stable_pretraining.utils.distance_metrics.compute_pairwise_distances_chunked(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean', chunk_size: int = 1024) Tensor[source]#
Memory-efficient computation of pairwise distances using chunking.
- Parameters:
x – Tensor of shape (n, d) containing n vectors of dimension d
y – Tensor of shape (m, d) containing m vectors of dimension d
metric – Distance metric to use
chunk_size – Process y in chunks of this size to save memory
- Returns:
Distance matrix of shape (n, m)
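The chunked variant is a drop-in replacement when y is large; a sketch:
```python
import torch
from stable_pretraining.utils.distance_metrics import compute_pairwise_distances_chunked

x = torch.randn(4096, 64)
y = torch.randn(20000, 64)
# y is processed 1024 rows at a time to keep intermediate buffers small;
# the final (4096, 20000) distance matrix is still returned in full.
d = compute_pairwise_distances_chunked(x, y, metric="euclidean", chunk_size=1024)
```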
stable_pretraining.utils.distributed module#
Distributed training utilities.
- class stable_pretraining.utils.distributed.FullGatherLayer(*args, **kwargs)[source]#
Bases: Function
Gather tensors from all processes and support backward propagation.
Supports backward propagation for the gradients across processes.
- static backward(ctx, grad)[source]#
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, x)[source]#
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See Combined or separate forward() and setup_context() for more details
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any: pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
The forward no longer accepts a ctx argument.
Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See Extending torch.autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- stable_pretraining.utils.distributed.all_gather(tensor, *args, **kwargs)[source]#
Gather tensors from all processes.
- Parameters:
tensor – The tensor to gather
*args – Additional arguments for all_gather
**kwargs – Additional keyword arguments for all_gather
- Returns:
Tuple containing the gathered tensors
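A hedged sketch of typical use in a distributed training step (only meaningful once torch.distributed is initialized; shapes are illustrative):
```python
import torch
from stable_pretraining.utils import all_gather

z_local = torch.randn(32, 128)                 # this rank's embeddings
z_all = torch.cat(all_gather(z_local), dim=0)  # concatenate the per-process tensors
```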
stable_pretraining.utils.error_handling module#
- stable_pretraining.utils.error_handling.with_hf_retry_ratelimit(func, *args, delay=10, max_attempts=100, **kwargs)[source]#
Calls the given function with retry logic for HTTP 429 (Too Many Requests) errors.
This function attempts to call func(*args, **kwargs). If a rate-limiting error (HTTP 429) is encountered (detected via exception type, status code, or error message), it waits for the duration specified by the HTTP Retry-After header (if present), or falls back to the delay parameter, and then retries. Retries continue up to max_attempts times. Non-429 errors are immediately re-raised. If all attempts fail due to 429, the last exception is raised.
- Exceptions handled:
huggingface_hub.utils.HfHubHTTPError
requests.exceptions.HTTPError
OSError
429 detection is performed by checking the exception's response.status_code (if available) or by searching for '429' or 'Too Many Requests' in the exception message.
- Parameters:
func (callable) – The function to call.
*args – Positional arguments to pass to func.
delay (int, optional) – Default wait time (in seconds) between retries if Retry-After is not provided. Defaults to 10.
max_attempts (int, optional) – Maximum number of attempts before giving up. Defaults to 100.
**kwargs – Keyword arguments to pass to func.
- Returns:
The return value of func(*args, **kwargs) if successful.
- Raises:
Exception – The original exception if a non-429 error occurs, or if all attempts fail.
Example
>>> from transformers import AutoModel
>>> model = with_hf_retry_ratelimit(
...     AutoModel.from_pretrained,
...     "facebook/ijepa_vith14_1k",
...     delay=10,
...     max_attempts=5,
... )
stable_pretraining.utils.gdrive_utils module#
Google Drive Background Uploader.
Upload files to Google Drive without blocking your application. Features: auto-versioning, background processing, comprehensive logging.
Example
>>> uploader = GDriveUploader("MyProject", "credentials.json")
>>> uploader.upload_file("large_file.mp4") # Returns immediately!
>>> uploader.wait_for_uploads() # Optional: wait for completion
- class stable_pretraining.utils.gdrive_utils.GDriveUploader(folder_name: str, credentials_path: str, parent_folder_id: str | None = None, callback: Callable[[str, str | None, bool], None] | None = None)[source]#
Bases: object
Background Google Drive uploader with automatic folder versioning.
All uploads happen in a background thread - your code never blocks! Automatically creates versioned folders (MyFolder_v2, MyFolder_v3, etc).
- Parameters:
folder_name – Name of the Drive folder to create
credentials_path – Path to service account JSON credentials
parent_folder_id – Optional parent folder ID (None = root)
callback – Optional function(file_path, file_id, success) called on completion
Example
>>> def notify(path, fid, ok):
...     print(f"{'✓' if ok else '✗'} {path}")
>>>
>>> uploader = GDriveUploader("Backups", "creds.json", callback=notify)
>>> uploader.upload_file("data.csv")  # Non-blocking!
>>> uploader.upload_file("logs.txt")
>>> uploader.wait_for_uploads()  # Wait for all to finish
- upload_file(file_path: str, custom_name: str | None = None, subfolder_id: str | None = None) None[source]#
Queue a file for background upload. Returns immediately!
- Parameters:
file_path – Path to file to upload
custom_name – Optional custom name in Drive
subfolder_id – Optional subfolder ID (defaults to main folder)
Example
>>> uploader.upload_file("video.mp4") >>> uploader.upload_file("report.pdf", custom_name="Q4_report.pdf")
- wait_for_uploads(timeout: float | None = None) bool[source]#
Wait for all queued uploads to complete.
- Parameters:
timeout – Max seconds to wait (None = wait forever)
- Returns:
True if completed, False if timeout/error
Example
>>> uploader.upload_file("file1.txt") >>> uploader.upload_file("file2.txt") >>> uploader.wait_for_uploads(timeout=300) # Wait max 5 min
stable_pretraining.utils.inspection_utils module#
Function inspection and general helper utilities.
- stable_pretraining.utils.inspection_utils.broadcast_param_to_list(param: Any, target_length: int, param_name: str) List[Any][source]#
Broadcast a parameter value to create a list of specified length.
This function handles the common pattern of accepting either:
- None: creates a list of None values
- A single value: broadcasts to all positions
- A single-element list/tuple: broadcasts the element to all positions
- A list/tuple of correct length: returns as-is
- Parameters:
param – The parameter to broadcast (can be None, single value, or list/tuple)
target_length – The desired length of the output list
param_name – Name of the parameter for error messages
- Returns:
List of values with length matching target_length
- Raises:
ValueError – If param is a list/tuple with length > 1 that doesn’t match target_length
Examples
>>> broadcast_param_to_list(None, 3, "dims")
[None, None, None]
>>> broadcast_param_to_list(5, 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([5], 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([1, 2, 3], 3, "dims")
[1, 2, 3]
stable_pretraining.utils.lightning_patch module#
Monkey-patch for PyTorch Lightning to support manual optimization with Trainer parameters.
This patch modifies Lightning’s validation to transfer gradient_clip_val and accumulate_grad_batches to alternative attributes instead of raising errors.
- stable_pretraining.utils.lightning_patch.apply_manual_optimization_patch()[source]#
Apply the monkey-patch to Lightning’s manual optimization validation.
This patch modifies the __verify_manual_optimization_support function to:
1. Transfer gradient_clip_val to gradient_clip_val_
2. Transfer accumulate_grad_batches to accumulate_grad_batches_
3. Clear the original values to avoid Lightning's error
This allows users to use standard Trainer parameters even with manual optimization.
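A sketch of where the patch is applied, before building the Trainer (the import name depends on your Lightning installation, e.g. pytorch_lightning in older setups; the Trainer arguments are illustrative):
```python
import lightning.pytorch as pl
from stable_pretraining.utils.lightning_patch import apply_manual_optimization_patch

apply_manual_optimization_patch()  # patch Lightning's manual-optimization validation once, up front

# With the patch, these arguments no longer raise under manual optimization; they are
# transferred to gradient_clip_val_ / accumulate_grad_batches_ as described above.
trainer = pl.Trainer(gradient_clip_val=1.0, accumulate_grad_batches=4, max_epochs=1)
```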
stable_pretraining.utils.log_reader module#
Unified log reader for local and wandb logs.
- class stable_pretraining.utils.log_reader.LocalLogReader(num_workers: int = 8)[source]#
Bases: LogReader
Reader for local jsonl log files.
- read(path: str | Path) List[Dict[str, Any]][source]#
Load values from a single run directory.
- Parameters:
path – Path to the run directory
- Returns:
List of log entries
- class stable_pretraining.utils.log_reader.LogReader[source]#
Bases: ABC
Abstract base class for log readers.
- class stable_pretraining.utils.log_reader.TableFormatter[source]#
Bases: object
Format experiment results as tables for analysis.
- static create_table(dfs: Dict[str, DataFrame], configs: Dict[str, Any], value: str, row: str, column: str, agg: Callable, filters: Dict[str, Any] | None = None) DataFrame[source]#
Format a pandas DataFrame as a table given the user args.
- Parameters:
dfs – Dictionary of DataFrames (one per run)
configs – Dictionary of configs (one per run)
value – Name of the column in dfs to use as values
row – Name of the column in configs to use as row
column – Name of the column in configs to use as column
agg – Aggregator function if many values are present
filters – Optional filters to apply to the data
- Returns:
Formatted table as DataFrame
- static tabulate_runs(configs: DataFrame, runs: List[Any], value: str, ignore: List[str] | None = None) DataFrame[source]#
Create a pivot table from configs and runs for a specific value.
- Parameters:
configs – DataFrame of run configurations
runs – List of run data
value – Value to extract from runs
ignore – Columns to ignore
- Returns:
Pivot table as DataFrame
- class stable_pretraining.utils.log_reader.WandbLogReader(num_workers: int = 10)[source]#
Bases: LogReader
Reader for Weights & Biases logs.
- read(entity: str, project: str, run_id: str, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, _tqdm_disable: bool = False) Tuple[DataFrame, Dict[str, Any]][source]#
Download data for a single wandb run.
- Parameters:
entity – Wandb entity name
project – Wandb project name
run_id – Run ID
min_step – Minimum step to download
max_step – Maximum step to download (-1 for all)
keys – Specific keys to download
_tqdm_disable – Whether to disable tqdm progress bar
- Returns:
Tuple of (data DataFrame, config dict)
- read_project(entity: str, project: str, filters: Dict[str, Any] | None = None, order: str = '+created_at', per_page: int = 50, include_sweeps: bool = True, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, return_summary: bool = False) Tuple[Dict[str, DataFrame], Dict[str, Dict]] | DataFrame[source]#
Download configs and data from a wandb project.
- Parameters:
entity – Wandb entity name
project – Wandb project name
filters – Optional filters for runs
order – Sort order for runs
per_page – Number of runs per page
include_sweeps – Whether to include sweep runs
min_step – Minimum step to download
max_step – Maximum step to download
keys – Specific keys to download
return_summary – If True, return only summary DataFrame
- Returns:
If return_summary is False: tuple of (dfs dict, configs dict). If return_summary is True: summary DataFrame.
- Return type:
Tuple[Dict[str, DataFrame], Dict[str, Dict]] | DataFrame
- stable_pretraining.utils.log_reader.alphanum_key(key: str) List[int | str][source]#
Convert a string to a list of mixed numbers and strings for natural sorting.
- stable_pretraining.utils.log_reader.create_results_table(dfs: Dict[str, DataFrame], configs: Dict[str, Any], value: str, row: str, column: str, agg: Callable = <function mean>, filters: Dict[str, Any] | None = None) DataFrame[source]#
Convenience function to create a results table.
- Parameters:
dfs – Dictionary of DataFrames (one per run)
configs – Dictionary of configs (one per run)
value – Name of the column in dfs to use as values
row – Name of the column in configs to use as row
column – Name of the column in configs to use as column
agg – Aggregator function (default: mean)
filters – Optional filters to apply
- Returns:
Formatted table as DataFrame
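A hedged sketch combining the wandb reader with this helper (entity, project, metric name, and config columns are all hypothetical):
```python
from stable_pretraining.utils.log_reader import create_results_table, read_wandb_project

dfs, configs = read_wandb_project("my-entity", "my-project", keys=["val/accuracy"])
table = create_results_table(
    dfs,
    configs,
    value="val/accuracy",   # column in each run DataFrame
    row="optimizer.lr",     # config field used for rows
    column="model.name",    # config field used for columns
)
print(table)
```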
- stable_pretraining.utils.log_reader.flatten_config(config: Dict[str, Any]) Dict[str, Any][source]#
Flatten nested config dictionaries into a single level.
- Parameters:
config – Nested configuration dictionary
- Returns:
Flattened configuration dictionary
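A small sketch (the exact separator used in flattened keys depends on the implementation):
```python
from stable_pretraining.utils.log_reader import flatten_config

cfg = {"optimizer": {"lr": 0.1, "weight_decay": 1e-4}, "seed": 0}
flat = flatten_config(cfg)  # single-level dict, e.g. {"optimizer.lr": 0.1, ...}
```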
- stable_pretraining.utils.log_reader.natural_sort(values: List[str]) List[str][source]#
Sort a list of strings naturally (handling numbers properly).
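For example, a sketch of the expected natural ordering:
```python
from stable_pretraining.utils.log_reader import natural_sort

print(natural_sort(["run10", "run2", "run1"]))  # ['run1', 'run2', 'run10'] rather than lexicographic order
```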
- stable_pretraining.utils.log_reader.read_local_logs(path: str | Path, num_workers: int = 8) List[Dict[str, Any]][source]#
Convenience function to read local logs.
- Parameters:
path – Path to the run directory
num_workers – Number of parallel workers
- Returns:
List of log entries
- stable_pretraining.utils.log_reader.read_local_project(folder: str | Path, num_workers: int = 8) Tuple[DataFrame, List[List[Dict[str, Any]]]][source]#
Convenience function to read a local project.
- Parameters:
folder – Path to the project folder
num_workers – Number of parallel workers
- Returns:
Tuple of (configs DataFrame, list of values)
- stable_pretraining.utils.log_reader.read_wandb_project(entity: str, project: str, filters: Dict[str, Any] | None = None, order: str = '+created_at', per_page: int = 50, include_sweeps: bool = True, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, num_workers: int = 10, return_summary: bool = False) Tuple[Dict[str, DataFrame], Dict[str, Dict]] | DataFrame[source]#
Convenience function to read a wandb project.
- Parameters:
entity – Wandb entity name
project – Wandb project name
filters – Optional filters for runs
order – Sort order for runs
per_page – Number of runs per page
include_sweeps – Whether to include sweep runs
min_step – Minimum step to download
max_step – Maximum step to download
keys – Specific keys to download
num_workers – Number of parallel workers
return_summary – If True, return only summary DataFrame
- Returns:
If return_summary is False: tuple of (dfs dict, configs dict). If return_summary is True: summary DataFrame.
- Return type:
Tuple[Dict[str, DataFrame], Dict[str, Dict]] | DataFrame
- stable_pretraining.utils.log_reader.read_wandb_run(entity: str, project: str, run_id: str, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, num_workers: int = 10) Tuple[DataFrame, Dict[str, Any]][source]#
Convenience function to read a wandb run.
- Parameters:
entity – Wandb entity name
project – Wandb project name
run_id – Run ID
min_step – Minimum step to download
max_step – Maximum step to download
keys – Specific keys to download
num_workers – Number of parallel workers
- Returns:
Tuple of (data DataFrame, config dict)
stable_pretraining.utils.nn_modules module#
Neural network modules and utilities.
- class stable_pretraining.utils.nn_modules.BatchNorm1dNoBias(*args, **kwargs)[source]#
Bases: BatchNorm1d
BatchNorm1d with learnable scale but no learnable bias (center=False).
This is used in contrastive learning methods like SimCLR where the final projection layer uses batch normalization with scale (gamma) but without bias (beta). This follows the original SimCLR implementation where the bias term is removed from the final BatchNorm layer.
The bias is frozen at 0 and set to non-trainable, while the weight (scale) parameter remains learnable.
Example
```python
# SimCLR-style projector
projector = nn.Sequential(
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
    spt.utils.nn_modules.BatchNorm1dNoBias(128),  # Final layer: no bias
)
```
Note
This is equivalent to TensorFlow’s BatchNorm with center=False, scale=True.
- class stable_pretraining.utils.nn_modules.EMA(alpha: float)[source]#
Bases: Module
Exponential Moving Average module.
Maintains an exponential moving average of input tensors.
- Parameters:
alpha – Smoothing factor between 0 and 1: 0 = no update (always return the first value); 1 = no smoothing (always return the current value).
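A hedged sketch, assuming the module is called on a stream of tensors and returns the current moving average:
```python
import torch
from stable_pretraining.utils.nn_modules import EMA

ema = EMA(alpha=0.1)  # small alpha = heavy smoothing toward earlier values (see alpha above)
for _ in range(5):
    smoothed = ema(torch.randn(8, 16))  # assumption: forward(x) updates and returns the average
```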
- class stable_pretraining.utils.nn_modules.ImageToVideoEncoder(encoder: Module)[source]#
Bases: Module
Wrapper to apply an image encoder to video data by processing each frame independently.
This module takes video data with shape (batch, time, channel, height, width) and applies an image encoder to each frame, returning the encoded features.
- Parameters:
encoder (torch.nn.Module) – The image encoder module to apply to each frame.
- forward(video)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
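A short sketch with a torchvision backbone (the backbone and shapes are illustrative; the output shape depends on the wrapped encoder):
```python
import torch
import torchvision
from stable_pretraining.utils.nn_modules import ImageToVideoEncoder

backbone = torchvision.models.resnet18(num_classes=128)  # any image encoder
video_encoder = ImageToVideoEncoder(backbone)
video = torch.randn(2, 8, 3, 224, 224)                   # (batch, time, channel, height, width)
features = video_encoder(video)                          # backbone applied to each frame
```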
- class stable_pretraining.utils.nn_modules.L2Norm(*args, **kwargs)[source]#
Bases: Module
L2 normalization layer that normalizes input to unit length.
Normalizes the input tensor along the last dimension to have unit L2 norm. Commonly used in DINO before the prototypes layer.
Example
```python
projector = nn.Sequential(
    nn.Linear(512, 2048),
    nn.GELU(),
    nn.Linear(2048, 256),
    spt.utils.nn_modules.L2Norm(),      # Normalize to unit length
    nn.Linear(256, 4096, bias=False),   # Prototypes
)
```
- class stable_pretraining.utils.nn_modules.Normalize(*args, **kwargs)[source]#
Bases: Module
Normalize tensor and scale by square root of number of elements.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.utils.nn_modules.OrderedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#
Bases: Module
A queue that maintains insertion order of elements.
Similar to UnsortedQueue but tracks the order in which items were inserted, allowing retrieval in the original insertion order even after wraparound.
- Parameters:
max_length – Maximum number of elements to store in the queue
shape – Shape of each element (excluding batch dimension). Can be int or tuple
dtype – Data type of the tensors to store
- append(item)[source]#
Append item(s) to the queue with order tracking.
- Parameters:
item – Tensor to append. First dimension is batch size.
- Returns:
Current contents of the queue in insertion order
- get()[source]#
Get current contents sorted by insertion order.
- Returns:
Tensor containing items sorted by their original insertion order
- get_unsorted()[source]#
Get current contents without sorting (like UnsortedQueue).
- Returns:
Tensor containing items in buffer order
- load_state_dict(state_dict, strict=True, assign=False)[source]#
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- class stable_pretraining.utils.nn_modules.UnsortedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#
Bases: Module
A queue data structure that stores tensors with a maximum length.
This module implements a circular buffer that stores tensors up to a maximum length. When the queue is full, new items overwrite the oldest ones.
- Parameters:
max_length – Maximum number of elements to store in the queue
shape – Shape of each element (excluding batch dimension). Can be int or tuple
dtype – Data type of the tensors to store
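A hedged sketch of both queues (shape and dtype follow the signatures above; append returning the current contents is documented for OrderedQueue and assumed for UnsortedQueue):
```python
import torch
from stable_pretraining.utils.nn_modules import OrderedQueue, UnsortedQueue

queue = OrderedQueue(max_length=256, shape=128, dtype=torch.float32)
for _ in range(10):
    queue.append(torch.randn(64, 128))  # 10 * 64 > 256, so the circular buffer wraps around
ordered = queue.get()                    # at most 256 rows, in insertion order

ring = UnsortedQueue(max_length=256, shape=128, dtype=torch.float32)
latest = ring.append(torch.randn(64, 128))
```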
stable_pretraining.utils.read_csv_logger module#
- class stable_pretraining.utils.read_csv_logger.CSVLogAutoSummarizer(agg: callable | None = None, monitor_keys: List[str] | None = None, include_globs: List[str] | None = None, exclude_globs: List[str] | None = None, max_workers: int | None = 10)[source]#
Bases: object
Automatically discovers and summarizes PyTorch Lightning CSVLogger metrics from Hydra multirun sweeps.
Handles arbitrary directory layouts, sparse metrics, multiple versions (preemption/resume), and aggregates config/hparams metadata into a single DataFrame.
Features:
- Recursively finds metrics CSVs using common patterns.
- Infers run root by searching for .hydra/config.yaml (falls back to metrics parent).
- Handles multiple metrics files per run with configurable grouping strategies: latest_mtime, latest_epoch, latest_step, or merge.
- Robust CSV parsing (delimiter sniffing, repeated-header cleanup, type coercion).
- Sparse-metrics aware: last values are last non-NaN; best values ignore NaNs.
- Loads and flattens Hydra config, overrides (list or dict), and hparams metadata.
Args:
base_dir: Root directory to search for metrics files.
monitor_keys: Metrics to summarize (if None, auto-infer).
include_globs: Only include files matching these globs (relative to base_dir).
exclude_globs: Exclude files matching these globs (e.g., ‘/checkpoints/’).
forward_fill_last: If True, forward-fill the frame (after sorting) before computing last.* summaries.
- METRICS_PATTERNS = ['**/metrics.csv', '**/csv/metrics.csv', '**/*metrics*.csv']#
- collect(base_dir) DataFrame[source]#
Discover, summarize, and aggregate all runs into a DataFrame.
Args:
base_dir: Root directory to search for metrics files.
- Returns:
One row per selected metrics source (per file or per run root), with flattened columns such as:
- last.val_accuracy
- best.val_loss, best.val_loss.step, best.val_loss.epoch
- config.optimizer.lr, override.0 (or override as joined string), hparams.seed
Also includes ‘metrics_path’ and ‘run_root’.
- Return type:
pd.DataFrame
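A hedged sketch over a Hydra multirun folder (the directory and metric name are hypothetical; the last.* column naming follows the description above):
```python
from stable_pretraining.utils.read_csv_logger import CSVLogAutoSummarizer

summarizer = CSVLogAutoSummarizer(monitor_keys=["val_accuracy"])
summary = summarizer.collect("multirun/2024-06-01")  # one row per metrics source
print(summary[["run_root", "last.val_accuracy"]].head())
```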
- stable_pretraining.utils.read_csv_logger.load_best_compressed(filename: str) DataFrame[source]#
Load a DataFrame from a compressed file with automatic format detection.
Automatically detects the file format from the extension and tries multiple read methods if needed. This is a smart loader that works with any compressed DataFrame file, not just those created by save_best_compressed.
- Parameters:
filename – Path to the compressed file. Standard extensions are recognized:
- .parquet → Parquet format
- .feather → Feather format
- .pkl, .pickle → Pickle format
- .csv, .csv.gz, .csv.bz2, etc. → CSV format
- Returns:
Loaded DataFrame with original dtypes preserved.
- Raises:
FileNotFoundError – If the specified file doesn’t exist.
ValueError – If no read method successfully loads the file.
- Supported Formats:
Parquet (compression auto-detected)
Feather (compression auto-detected)
CSV (compression auto-detected from extension)
Pickle (compression auto-detected)
Notes
Extension-based detection is tried first for speed
If extension is ambiguous/missing, tries all formats sequentially
Compression is handled automatically by pandas
Works with any properly formatted DataFrame file, not just from this module
Examples
>>> # Load with clear extension
>>> df = load_best_compressed("data.parquet")
>>> # Load compressed CSV
>>> df = load_best_compressed("results.csv.gz")
>>> # Load with full path
>>> df = load_best_compressed("/path/to/experiment.feather")
>>> # Even works with ambiguous extensions (tries all methods)
>>> df = load_best_compressed("mystery_file.dat")
See also
save_best_compressed: Saves DataFrame with optimal compression
- stable_pretraining.utils.read_csv_logger.save_best_compressed(df: DataFrame, base_path: str = 'data') str[source]#
Optimize DataFrame and save using the most space-efficient format/compression combination.
This function:
1. Optimizes DataFrame dtypes to reduce memory usage
2. Tries multiple format/compression combinations in parallel
3. Keeps only the smallest resulting file with a clean filename
4. Automatically cleans up all trial files
The following format/compression combinations are tested:
- Parquet: brotli, zstd, gzip, snappy
- Feather: zstd, lz4, uncompressed
- CSV: gzip, bz2, xz, zstd, zip
- Pickle: infer
- Parameters:
df – DataFrame to save. Will be optimized before saving.
base_path – Base path/name for output file. Any extensions will be automatically removed and replaced with the optimal format. Default: ‘data’
- Returns:
Filename of the best-compressed file (smallest size) with clean extension.
- Raises:
RuntimeError – If all save attempts fail (no files were created).
Notes
Uses parallel processing (ThreadPoolExecutor) for I/O-bound operations
All trial files except the smallest are automatically deleted
Original DataFrame is not modified (optimization works on a copy)
Final file uses standard extensions (e.g., .parquet, .csv.gz, .feather)
Any extensions in base_path are automatically stripped
If multiple formats produce the same size, the first one sorted is kept
- Performance:
Time: Depends on number of trials and DataFrame size
Memory: Peak usage ~2x DataFrame size (original + optimized copy)
Disk: Temporarily creates all trial files before cleanup
Examples
>>> df = pd.DataFrame(
...     {
...         "timestamp": pd.date_range("2024-01-01", periods=1000),
...         "value": np.random.randn(1000),
...         "category": np.random.choice(["A", "B", "C"], 1000),
...     }
... )
>>> # Extension automatically added
>>> best_file = save_best_compressed(df, "results/experiment_01")
>>> print(best_file)
'results/experiment_01.parquet'
>>> # Extensions in input are stripped
>>> best_file = save_best_compressed(df, "data.csv")
>>> print(best_file)
'data.parquet'  # .csv was stripped, optimal format chosen
>>> # Works with multiple extensions
>>> best_file = save_best_compressed(df, "output.csv.gz")
>>> print(best_file)
'output.feather'  # All extensions stripped
See also
_optimize_dataframe: For details on DataFrame optimization
_get_trials: For the complete list of format/compression combinations
load_best_compressed: Smart loader that auto-detects format
stable_pretraining.utils.timm_to_hf_hub module#
- stable_pretraining.utils.timm_to_hf_hub.push_timm_to_hf(model_name: str, model: Module, repo_id: str, hf_token: str | None = None, private: bool = False, validate: bool = True, batch_size: int = 2, atol: float = 0.0001, rtol: float = 0.0001, device: str | None = None, strict: bool = False) str[source]#
stable_pretraining.utils.visualization module#
Visualization utilities for SSL experiments.
- stable_pretraining.utils.visualization.escape_labels(idx_or_cols)[source]#
Recursively escape all labels in a pandas Index or MultiIndex.
- stable_pretraining.utils.visualization.format_df_to_latex(df, caption=None, label=None, bold='row', na_rep='–', sort_index=False, sort_columns=False, column_format=None, position='htbp', escape_headers=True, show_percent_symbol=False, unit_annotation='caption')[source]#
Format a MultiIndex DataFrame for LaTeX export with percent formatting (no % symbol).
Escapes LaTeX special characters in all headers if escape_headers=True.
- stable_pretraining.utils.visualization.imshow_with_grid(ax, G: ndarray | Tensor, linewidth: float | None = 0.4, color: str | tuple | None = 'black', bars=[], **kwargs)[source]#
Display a matrix with grid lines overlaid.
- Parameters:
ax – Matplotlib axes to plot on
G – Matrix to display
linewidth – Width of grid lines
color – Color of grid lines
bars – List of bar specifications for highlighting regions
**kwargs – Additional arguments for imshow
- Returns:
The image object from imshow
- stable_pretraining.utils.visualization.latex_escape(s)[source]#
Escape LaTeX special characters in a string.
- stable_pretraining.utils.visualization.visualize_images_graph(x, G, zoom_on=8)[source]#
Visualize images and their similarity graph with zoom detail.
Creates a visualization showing:
- A grid of sample images
- The full similarity matrix
- A zoomed-in view of the top-left portion of the matrix
- Connection lines between the views
- Parameters:
x – List or tensor of images
G – Similarity/adjacency matrix
zoom_on – Number of rows/columns to show in zoomed view
Module contents#
Stable-pretraining utilities package.
This package provides various utilities for self-supervised learning experiments including distributed training helpers, custom autograd functions, neural network modules, stable linear algebra operations, data generation, visualization, and configuration management.
- class stable_pretraining.utils.BatchNorm1dNoBias(*args, **kwargs)[source]#
Bases: BatchNorm1d
BatchNorm1d with learnable scale but no learnable bias (center=False).
This is used in contrastive learning methods like SimCLR where the final projection layer uses batch normalization with scale (gamma) but without bias (beta). This follows the original SimCLR implementation where the bias term is removed from the final BatchNorm layer.
The bias is frozen at 0 and set to non-trainable, while the weight (scale) parameter remains learnable.
Example
```python
# SimCLR-style projector
projector = nn.Sequential(
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
    spt.utils.nn_modules.BatchNorm1dNoBias(128),  # Final layer: no bias
)
```
Note
This is equivalent to TensorFlow’s BatchNorm with center=False, scale=True.
- class stable_pretraining.utils.CSVLogAutoSummarizer(agg: callable | None = None, monitor_keys: List[str] | None = None, include_globs: List[str] | None = None, exclude_globs: List[str] | None = None, max_workers: int | None = 10)[source]#
Bases: object
Automatically discovers and summarizes PyTorch Lightning CSVLogger metrics from Hydra multirun sweeps.
Handles arbitrary directory layouts, sparse metrics, multiple versions (preemption/resume), and aggregates config/hparams metadata into a single DataFrame.
Features:
- Recursively finds metrics CSVs using common patterns.
- Infers run root by searching for .hydra/config.yaml (falls back to metrics parent).
- Handles multiple metrics files per run with configurable grouping strategies: latest_mtime, latest_epoch, latest_step, or merge.
- Robust CSV parsing (delimiter sniffing, repeated-header cleanup, type coercion).
- Sparse-metrics aware: last values are last non-NaN; best values ignore NaNs.
- Loads and flattens Hydra config, overrides (list or dict), and hparams metadata.
Args:
base_dir: Root directory to search for metrics files.
monitor_keys: Metrics to summarize (if None, auto-infer).
include_globs: Only include files matching these globs (relative to base_dir).
exclude_globs: Exclude files matching these globs (e.g., ‘/checkpoints/’).
forward_fill_last: If True, forward-fill the frame (after sorting) before computing last.* summaries.
- METRICS_PATTERNS = ['**/metrics.csv', '**/csv/metrics.csv', '**/*metrics*.csv']#
- collect(base_dir) DataFrame[source]#
Discover, summarize, and aggregate all runs into a DataFrame.
Args:
base_dir: Root directory to search for metrics files.
- Returns:
One row per selected metrics source (per file or per run root), with flattened columns such as:
- last.val_accuracy
- best.val_loss, best.val_loss.step, best.val_loss.epoch
- config.optimizer.lr, override.0 (or override as joined string), hparams.seed
Also includes ‘metrics_path’ and ‘run_root’.
- Return type:
pd.DataFrame
- class stable_pretraining.utils.EMA(alpha: float)[source]#
Bases: Module
Exponential Moving Average module.
Maintains an exponential moving average of input tensors.
- Parameters:
alpha – Smoothing factor between 0 and 1: 0 = no update (always return the first value); 1 = no smoothing (always return the current value).
- class stable_pretraining.utils.FullGatherLayer(*args, **kwargs)[source]#
Bases: Function
Gather tensors from all processes and support backward propagation.
Supports backward propagation for the gradients across processes.
- static backward(ctx, grad)[source]#
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, x)[source]#
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See Combined or separate forward() and setup_context() for more details
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any: pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
The forward no longer accepts a ctx argument.
Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward. See Extending torch.autograd for more details.
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class stable_pretraining.utils.GDriveUploader(folder_name: str, credentials_path: str, parent_folder_id: str | None = None, callback: Callable[[str, str | None, bool], None] | None = None)[source]#
Bases: object
Background Google Drive uploader with automatic folder versioning.
All uploads happen in a background thread - your code never blocks! Automatically creates versioned folders (MyFolder_v2, MyFolder_v3, etc).
- Parameters:
folder_name – Name of the Drive folder to create
credentials_path – Path to service account JSON credentials
parent_folder_id – Optional parent folder ID (None = root)
callback – Optional function(file_path, file_id, success) called on completion
Example
>>> def notify(path, fid, ok):
...     print(f"{'✓' if ok else '✗'} {path}")
>>>
>>> uploader = GDriveUploader("Backups", "creds.json", callback=notify)
>>> uploader.upload_file("data.csv")  # Non-blocking!
>>> uploader.upload_file("logs.txt")
>>> uploader.wait_for_uploads()  # Wait for all to finish
- upload_file(file_path: str, custom_name: str | None = None, subfolder_id: str | None = None) None[source]#
Queue a file for background upload. Returns immediately!
- Parameters:
file_path – Path to file to upload
custom_name – Optional custom name in Drive
subfolder_id – Optional subfolder ID (defaults to main folder)
Example
>>> uploader.upload_file("video.mp4") >>> uploader.upload_file("report.pdf", custom_name="Q4_report.pdf")
- wait_for_uploads(timeout: float | None = None) bool[source]#
Wait for all queued uploads to complete.
- Parameters:
timeout – Max seconds to wait (None = wait forever)
- Returns:
True if completed, False if timeout/error
Example
>>> uploader.upload_file("file1.txt") >>> uploader.upload_file("file2.txt") >>> uploader.wait_for_uploads(timeout=300) # Wait max 5 min
- class stable_pretraining.utils.ImageToVideoEncoder(encoder: Module)[source]#
Bases: Module
Wrapper to apply an image encoder to video data by processing each frame independently.
This module takes video data with shape (batch, time, channel, height, width) and applies an image encoder to each frame, returning the encoded features.
- Parameters:
encoder (torch.nn.Module) – The image encoder module to apply to each frame.
- forward(video)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.utils.L2Norm(*args, **kwargs)[source]#
Bases: Module
L2 normalization layer that normalizes input to unit length.
Normalizes the input tensor along the last dimension to have unit L2 norm. Commonly used in DINO before the prototypes layer.
Example
```python
projector = nn.Sequential(
    nn.Linear(512, 2048),
    nn.GELU(),
    nn.Linear(2048, 256),
    spt.utils.nn_modules.L2Norm(),      # Normalize to unit length
    nn.Linear(256, 4096, bias=False),   # Prototypes
)
```
- class stable_pretraining.utils.Normalize(*args, **kwargs)[source]#
Bases: Module
Normalize tensor and scale by square root of number of elements.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.utils.OrderedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#
Bases: Module
A queue that maintains insertion order of elements.
Similar to UnsortedQueue but tracks the order in which items were inserted, allowing retrieval in the original insertion order even after wraparound.
- Parameters:
max_length – Maximum number of elements to store in the queue
shape – Shape of each element (excluding batch dimension). Can be int or tuple
dtype – Data type of the tensors to store
- append(item)[source]#
Append item(s) to the queue with order tracking.
- Parameters:
item – Tensor to append. First dimension is batch size.
- Returns:
Current contents of the queue in insertion order
- get()[source]#
Get current contents sorted by insertion order.
- Returns:
Tensor containing items sorted by their original insertion order
- get_unsorted()[source]#
Get current contents without sorting (like UnsortedQueue).
- Returns:
Tensor containing items in buffer order
- load_state_dict(state_dict, strict=True, assign=False)[source]#
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
- class stable_pretraining.utils.UnsortedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#
Bases: Module
A queue data structure that stores tensors with a maximum length.
This module implements a circular buffer that stores tensors up to a maximum length. When the queue is full, new items overwrite the oldest ones.
- Parameters:
max_length – Maximum number of elements to store in the queue
shape – Shape of each element (excluding batch dimension). Can be int or tuple
dtype – Data type of the tensors to store
- stable_pretraining.utils.adapt_resnet_for_lowres(model)[source]#
Adapt a ResNet model for low resolution images.
Modifies the first convolution layer to use 3x3 kernels with stride 1 and removes the max pooling layer.
- Parameters:
model – ResNet model to adapt
- Returns:
Modified model
- stable_pretraining.utils.all_gather(tensor, *args, **kwargs)[source]#
Gather tensors from all processes.
- Parameters:
tensor – The tensor to gather
*args – Additional arguments for all_gather
**kwargs – Additional keyword arguments for all_gather
- Returns:
Tuple containing the gathered tensors
- stable_pretraining.utils.all_reduce(tensor, *args, **kwargs)[source]#
Reduce tensors across all processes.
- Parameters:
tensor – The tensor to reduce
*args – Additional arguments for all_reduce
**kwargs – Additional keyword arguments for all_reduce
- Returns:
The reduced tensor
- stable_pretraining.utils.broadcast_param_to_list(param: Any, target_length: int, param_name: str) List[Any][source]#
Broadcast a parameter value to create a list of specified length.
This function handles the common pattern of accepting either:
- None: creates a list of None values
- A single value: broadcasts to all positions
- A single-element list/tuple: broadcasts the element to all positions
- A list/tuple of correct length: returns as-is
- Parameters:
param – The parameter to broadcast (can be None, single value, or list/tuple)
target_length – The desired length of the output list
param_name – Name of the parameter for error messages
- Returns:
List of values with length matching target_length
- Raises:
ValueError – If param is a list/tuple with length > 1 that doesn’t match target_length
Examples
>>> broadcast_param_to_list(None, 3, "dims")
[None, None, None]
>>> broadcast_param_to_list(5, 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([5], 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([1, 2, 3], 3, "dims")
[1, 2, 3]
- stable_pretraining.utils.compute_pairwise_distances(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean') Tensor[source]#
Compute pairwise distances between two sets of vectors.
- Parameters:
x – Tensor of shape (n, d) containing n vectors of dimension d
y – Tensor of shape (m, d) containing m vectors of dimension d
metric – Distance metric to use. Options: “euclidean” (L2 distance), “squared_euclidean” (squared L2 distance), “cosine” (cosine distance, 1 - cosine_similarity), “manhattan” (L1 distance).
- Returns:
Distance matrix of shape (n, m) where element (i, j) is the distance between x[i] and y[j]
- stable_pretraining.utils.compute_pairwise_distances_chunked(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean', chunk_size: int = 1024) Tensor[source]#
Memory-efficient computation of pairwise distances using chunking.
- Parameters:
x – Tensor of shape (n, d) containing n vectors of dimension d
y – Tensor of shape (m, d) containing m vectors of dimension d
metric – Distance metric to use
chunk_size – Process y in chunks of this size to save memory
- Returns:
Distance matrix of shape (n, m)
- stable_pretraining.utils.detach_tensors(obj: Any) Any[source]#
Recursively traverse an object and return an equivalent structure with all torch tensors detached.
Preserves structure, types, and shared references.
Handles cycles and arbitrary Python objects (including __dict__ and __slots__).
Does not mutate the input; only rebuilds containers if needed.
torch.nn.Parameter is replaced with a detached Tensor (not Parameter).
Optionally supports attrs classes if ‘attr’ is installed.
- Parameters:
obj – The input object (can be arbitrarily nested).
- Returns:
A new object with all torch tensors detached, or the original object if no tensors found.
- Performance notes:
Uses memoization to avoid redundant work and preserve shared/cyclic structure.
Avoids unnecessary copies: unchanged subtrees are returned as-is (same id).
Shallow-copies objects with __dict__ or __slots__ (does not call __init__).
- stable_pretraining.utils.dict_values(**kwargs)[source]#
Convert keyword arguments to a list of values.
- Returns:
List of values from the provided keyword arguments
- stable_pretraining.utils.execute_from_config(manager, cfg)[source]#
Execute a function with support for submitit job submission.
If submitit configuration is present, submits the job to a cluster. Otherwise executes locally.
- Parameters:
manager – Function or callable to execute
cfg – Configuration dictionary
- Returns:
Result of the executed function
- stable_pretraining.utils.find_module(model: Module, module: Module)[source]#
Find all instances of a module type in a model.
- Parameters:
model – Model to search in
module – Module class to search for
- Returns:
Tuple of (names, modules) where names are the module paths and modules are the actual module instances
- stable_pretraining.utils.format_df_to_latex(df, caption=None, label=None, bold='row', na_rep='–', sort_index=False, sort_columns=False, column_format=None, position='htbp', escape_headers=True, show_percent_symbol=False, unit_annotation='caption')[source]#
Format a MultiIndex DataFrame for LaTeX export with percent formatting (no % symbol).
Escapes LaTeX special characters in all headers if escape_headers=True.
- stable_pretraining.utils.generate_dae_samples(x, n, eps, num_workers=10)[source]#
Generate samples for Denoising Autoencoder (DAE) training.
- Parameters:
x – List of input images
n – Number of noisy versions per image
eps – Noise level (variance)
num_workers – Number of parallel workers
- Returns:
Tuple of (noisy_images, similarity_matrix)
- stable_pretraining.utils.generate_dm_samples(x, n, betas, i, num_workers=10)[source]#
Generate samples for Diffusion Model training.
- Parameters:
x – List of input images
n – Number of noisy versions per timestep
betas – Noise schedule beta values
i – Timestep indices to use
num_workers – Number of parallel workers
- Returns:
Tuple of (noisy_images, similarity_matrix)
- stable_pretraining.utils.generate_ssl_samples(x, n, num_workers=10)[source]#
Generate augmented samples for self-supervised learning.
Creates n augmented versions of each image.
- Parameters:
x – List of input images
n – Number of augmented versions per image
num_workers – Number of parallel workers
- Returns:
Tuple of (augmented_images, similarity_matrix)
- stable_pretraining.utils.generate_sup_samples(x, y, n, num_workers=10)[source]#
Generate samples for supervised learning with class structure.
Only includes classes with at least n samples.
- Parameters:
x – List of input images
y – Class labels
n – Minimum samples per class
num_workers – Number of parallel workers
- Returns:
Tuple of (processed_images, class_similarity_matrix)
- stable_pretraining.utils.get_data_from_batch_or_outputs(key: Iterable[str] | str, batch: Dict[str, Any], outputs: Dict[str, Any] | None = None, caller_name: str = 'Callback') Any | None[source]#
Get data from either outputs or batch dictionary.
In PyTorch Lightning, the outputs parameter in callbacks contains the return value from training_step/validation_step, while batch contains the original input. Since forward methods may modify batch in-place but Lightning creates a copy for outputs, we need to check both.
- Parameters:
key – The key(s) to look for in the dictionaries
batch – The original batch dictionary
outputs – The outputs dictionary from training/validation step
caller_name – Name of the calling function/class for logging
- Returns:
The data associated with the key, or None if not found
- stable_pretraining.utils.get_required_fn_parameters(fn)[source]#
Get the list of required parameters for a function.
- Parameters:
fn – The function to inspect
- Returns:
List of parameter names that don’t have default values
- stable_pretraining.utils.is_dist_avail_and_initialized()[source]#
Check if distributed training is available and initialized.
- Returns:
True if distributed is available and initialized, False otherwise
- Return type:
bool
- stable_pretraining.utils.load_hparams_from_ckpt(path: str) Any[source]#
Loads a checkpoint safely in both distributed and non-distributed settings.
If torch.distributed is initialized, only src_rank loads from disk, then broadcasts to all other ranks.
If not, loads directly from disk.
- Parameters:
path – Path to checkpoint file.
- Returns:
The loaded hparams.
- stable_pretraining.utils.replace_module(model, replacement_mapping)[source]#
Replace modules in a model based on a mapping function.
- Parameters:
model – PyTorch model to modify
replacement_mapping – Function that takes (name, module) and returns the replacement module
- Returns:
Modified model
- stable_pretraining.utils.rgetattr(obj, attr)[source]#
Recursively get an attribute using dot notation.
- Parameters:
obj – Object to get attribute from
attr – Attribute path (e.g., “module.layer.weight”)
- Returns:
The requested attribute value
- stable_pretraining.utils.rsetattr(obj, attr, val)[source]#
Recursively set an attribute using dot notation.
- Parameters:
obj – Object to set attribute on
attr – Attribute path (e.g., “module.layer.weight”)
val – Value to set
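A short sketch of both helpers on a torchvision model (the attribute paths are illustrative):
```python
import torch.nn as nn
import torchvision
from stable_pretraining.utils import rgetattr, rsetattr

model = torchvision.models.resnet18()
weight = rgetattr(model, "fc.weight")      # same as model.fc.weight
rsetattr(model, "fc", nn.Linear(512, 10))  # swap the classification head in place
```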
- stable_pretraining.utils.with_hf_retry_ratelimit(func, *args, delay=10, max_attempts=100, **kwargs)[source]#
Calls the given function with retry logic for HTTP 429 (Too Many Requests) errors.
This function attempts to call func(*args, **kwargs). If a rate-limiting error (HTTP 429) is encountered (detected via exception type, status code, or error message), it waits for the duration specified by the HTTP Retry-After header (if present), or falls back to the delay parameter, and then retries. Retries continue up to max_attempts times. Non-429 errors are immediately re-raised. If all attempts fail due to 429, the last exception is raised.
- Exceptions handled:
huggingface_hub.utils.HfHubHTTPError
requests.exceptions.HTTPError
OSError
429 detection is performed by checking the exception's response.status_code (if available) or by searching for '429' or 'Too Many Requests' in the exception message.
- Parameters:
func (callable) – The function to call.
*args – Positional arguments to pass to func.
delay (int, optional) – Default wait time (in seconds) between retries if Retry-After is not provided. Defaults to 10.
max_attempts (int, optional) – Maximum number of attempts before giving up. Defaults to 100.
**kwargs – Keyword arguments to pass to func.
- Returns:
The return value of func(*args, **kwargs) if successful.
- Raises:
Exception – The original exception if a non-429 error occurs, or if all attempts fail.
Example
>>> from transformers import AutoModel
>>> model = with_hf_retry_ratelimit(
...     AutoModel.from_pretrained,
...     "facebook/ijepa_vith14_1k",
...     delay=10,
...     max_attempts=5,
... )