stable_pretraining.utils package#

Submodules#

stable_pretraining.utils.batch_utils module#

Utility functions for handling batch and outputs dictionaries in callbacks.

stable_pretraining.utils.batch_utils.detach_tensors(obj: Any) Any[source]#

Recursively traverse an object and return an equivalent structure with all torch tensors detached.

  • Preserves structure, types, and shared references.

  • Handles cycles and arbitrary Python objects (including __dict__ and __slots__).

  • Does not mutate the input; only rebuilds containers if needed.

  • torch.nn.Parameter is replaced with a detached Tensor (not Parameter).

  • Optionally supports attrs classes if ‘attr’ is installed.

Parameters:

obj – The input object (can be arbitrarily nested).

Returns:

A new object with all torch tensors detached, or the original object if no tensors found.

Performance notes:
  • Uses memoization to avoid redundant work and preserve shared/cyclic structure.

  • Avoids unnecessary copies: unchanged subtrees are returned as-is (same id).

  • Shallow-copies objects with __dict__ or __slots__ (does not call __init__).
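A minimal usage sketch (illustrative; the nested batch below is hypothetical):

>>> import torch
>>> batch = {"embedding": torch.randn(4, 8, requires_grad=True), "label": [0, 1, 2, 3]}
>>> detached = detach_tensors(batch)
>>> detached["embedding"].requires_grad
False
>>> detached["label"] is batch["label"]  # tensor-free subtrees are returned as-is
True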

stable_pretraining.utils.batch_utils.get_data_from_batch_or_outputs(key: Iterable[str] | str, batch: Dict[str, Any], outputs: Dict[str, Any] | None = None, caller_name: str = 'Callback') Any | None[source]#

Get data from either outputs or batch dictionary.

In PyTorch Lightning, the outputs parameter in callbacks contains the return value from training_step/validation_step, while batch contains the original input. Since forward methods may modify batch in-place but Lightning creates a copy for outputs, we need to check both.

Parameters:
  • key – The key(s) to look for in the dictionaries

  • batch – The original batch dictionary

  • outputs – The outputs dictionary from training/validation step

  • caller_name – Name of the calling function/class for logging

Returns:

The data associated with the key, or None if not found
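For instance, inside a Lightning callback hook (a sketch; the "embedding" key and callback name are hypothetical):

>>> def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
...     embedding = get_data_from_batch_or_outputs(
...         "embedding", batch, outputs, caller_name="MyCallback"
...     )
...     if embedding is None:
...         return  # key missing from both dictionaries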

stable_pretraining.utils.config module#

Configuration utilities and model manipulation helpers.

stable_pretraining.utils.config.adapt_resnet_for_lowres(model)[source]#

Adapt a ResNet model for low resolution images.

Modifies the first convolution layer to use 3x3 kernels with stride 1 and removes the max pooling layer.

Parameters:

model – ResNet model to adapt

Returns:

Modified model
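A typical use is preparing a torchvision ResNet for CIFAR-sized (32x32) images (sketch, assuming torchvision is installed):

>>> import torchvision
>>> model = torchvision.models.resnet18(weights=None)
>>> model = adapt_resnet_for_lowres(model)  # 3x3 stem conv, stride 1, max pooling removed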

stable_pretraining.utils.config.execute_from_config(manager, cfg)[source]#

Execute a function with support for submitit job submission.

If submitit configuration is present, submits the job to a cluster. Otherwise executes locally.

Parameters:
  • manager – Function or callable to execute

  • cfg – Configuration dictionary

Returns:

Result of the executed function

stable_pretraining.utils.config.find_module(model: Module, module: Module)[source]#

Find all instances of a module type in a model.

Parameters:
  • model – Model to search in

  • module – Module class to search for

Returns:

Tuple of (names, modules) where names are the module paths and modules are the actual module instances
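For example, locating every BatchNorm2d layer in a model (sketch, assuming a torchvision ResNet):

>>> import torch.nn as nn
>>> import torchvision
>>> model = torchvision.models.resnet18(weights=None)
>>> names, modules = find_module(model, nn.BatchNorm2d)
>>> len(names) == len(modules)
True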

stable_pretraining.utils.config.is_dist() bool[source]#

Returns True if torch.distributed is available and initialized.

stable_pretraining.utils.config.load_hparams_from_ckpt(path: str) Any[source]#

Loads a checkpoint safely in both distributed and non-distributed settings.

  • If torch.distributed is initialized, only src_rank loads from disk, then broadcasts to all other ranks.

  • If not, loads directly from disk.

Parameters:

path – Path to checkpoint file.

Returns:

The loaded hparams.

stable_pretraining.utils.config.replace_module(model, replacement_mapping)[source]#

Replace modules in a model based on a mapping function.

Parameters:
  • model – PyTorch model to modify

  • replacement_mapping – Function that takes (name, module) and returns the replacement module

Returns:

Modified model
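For example, swapping every ReLU for GELU (sketch; the mapping function is user-defined):

>>> import torch.nn as nn
>>> import torchvision
>>> model = torchvision.models.resnet18(weights=None)
>>> def relu_to_gelu(name, module):
...     return nn.GELU() if isinstance(module, nn.ReLU) else module
>>> model = replace_module(model, relu_to_gelu)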

stable_pretraining.utils.config.rgetattr(obj, attr)[source]#

Recursively get an attribute using dot notation.

Parameters:
  • obj – Object to get attribute from

  • attr – Attribute path (e.g., “module.layer.weight”)

Returns:

The requested attribute value

stable_pretraining.utils.config.rsetattr(obj, attr, val)[source]#

Recursively set an attribute using dot notation.

Parameters:
  • obj – Object to set attribute on

  • attr – Attribute path (e.g., “module.layer.weight”)

  • val – Value to set
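A brief sketch of both helpers (assumes a torchvision-style ResNet with these attribute paths):

>>> import torch.nn as nn
>>> import torchvision
>>> model = torchvision.models.resnet18(weights=None)
>>> weight = rgetattr(model, "layer1.0.conv1.weight")
>>> rsetattr(model, "fc", nn.Linear(512, 10))  # replace the classification head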

stable_pretraining.utils.data_generation module#

Sample generation utilities for SSL experiments.

stable_pretraining.utils.data_generation.generate_dae_samples(x, n, eps, num_workers=10)[source]#

Generate samples for Denoising Autoencoder (DAE) training.

Parameters:
  • x – List of input images

  • n – Number of noisy versions per image

  • eps – Noise level (variance)

  • num_workers – Number of parallel workers

Returns:

Tuple of (noisy_images, similarity_matrix)

stable_pretraining.utils.data_generation.generate_dm_samples(x, n, betas, i, num_workers=10)[source]#

Generate samples for Diffusion Model training.

Parameters:
  • x – List of input images

  • n – Number of noisy versions per timestep

  • betas – Noise schedule beta values

  • i – Timestep indices to use

  • num_workers – Number of parallel workers

Returns:

Tuple of (noisy_images, similarity_matrix)

stable_pretraining.utils.data_generation.generate_ssl_samples(x, n, num_workers=10)[source]#

Generate augmented samples for self-supervised learning.

Creates n augmented versions of each image.

Parameters:
  • x – List of input images

  • n – Number of augmented versions per image

  • num_workers – Number of parallel workers

Returns:

Tuple of (augmented_images, similarity_matrix)

stable_pretraining.utils.data_generation.generate_sup_samples(x, y, n, num_workers=10)[source]#

Generate samples for supervised learning with class structure.

Only includes classes with at least n samples.

Parameters:
  • x – List of input images

  • y – Class labels

  • n – Minimum samples per class

  • num_workers – Number of parallel workers

Returns:

Tuple of (processed_images, class_similarity_matrix)

stable_pretraining.utils.distance_metrics module#

Distance metric functions for computing pairwise distances between tensors.

stable_pretraining.utils.distance_metrics.compute_pairwise_distances(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean') Tensor[source]#

Compute pairwise distances between two sets of vectors.

Parameters:
  • x – Tensor of shape (n, d) containing n vectors of dimension d

  • y – Tensor of shape (m, d) containing m vectors of dimension d

  • metric – Distance metric to use. Options: "euclidean" (L2 distance), "squared_euclidean" (squared L2 distance), "cosine" (cosine distance, 1 - cosine_similarity), "manhattan" (L1 distance)

Returns:

Distance matrix of shape (n, m) where element (i, j) is the distance between x[i] and y[j]
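A quick shape sketch:

>>> import torch
>>> x = torch.randn(128, 64)
>>> y = torch.randn(256, 64)
>>> compute_pairwise_distances(x, y, metric="cosine").shape
torch.Size([128, 256])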

stable_pretraining.utils.distance_metrics.compute_pairwise_distances_chunked(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean', chunk_size: int = 1024) Tensor[source]#

Memory-efficient computation of pairwise distances using chunking.

Parameters:
  • x – Tensor of shape (n, d) containing n vectors of dimension d

  • y – Tensor of shape (m, d) containing m vectors of dimension d

  • metric – Distance metric to use

  • chunk_size – Process y in chunks of this size to save memory

Returns:

Distance matrix of shape (n, m)

stable_pretraining.utils.distributed module#

Distributed training utilities.

class stable_pretraining.utils.distributed.FullGatherLayer(*args, **kwargs)[source]#

Bases: Function

Gather tensors from all processes and support backward propagation.

Supports backward propagation for the gradients across processes.

static backward(ctx, grad)[source]#

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, x)[source]#

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

stable_pretraining.utils.distributed.all_gather(tensor, *args, **kwargs)[source]#

Gather tensors from all processes.

Parameters:
  • tensor – The tensor to gather

  • *args – Additional arguments for all_gather

  • **kwargs – Additional keyword arguments for all_gather

Returns:

Tuple containing the gathered tensors
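In multi-GPU training this is typically used to assemble per-rank embeddings before computing a global loss (sketch; assumes torch.distributed is initialized and a CUDA device is available):

>>> local = torch.randn(32, 128, device="cuda")
>>> gathered = all_gather(local)       # tuple with one tensor per process
>>> full = torch.cat(gathered, dim=0)  # (32 * world_size, 128)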

stable_pretraining.utils.distributed.all_reduce(tensor, *args, **kwargs)[source]#

Reduce tensors across all processes.

Parameters:
  • tensor – The tensor to reduce

  • *args – Additional arguments for all_reduce

  • **kwargs – Additional keyword arguments for all_reduce

Returns:

The reduced tensor

stable_pretraining.utils.distributed.is_dist_avail_and_initialized()[source]#

Check if distributed training is available and initialized.

Returns:

True if distributed is available and initialized, False otherwise

Return type:

bool

stable_pretraining.utils.error_handling module#

stable_pretraining.utils.error_handling.with_hf_retry_ratelimit(func, *args, delay=10, max_attempts=100, **kwargs)[source]#

Calls the given function with retry logic for HTTP 429 (Too Many Requests) errors.

This function attempts to call func(*args, **kwargs). If a rate-limiting error (HTTP 429) is encountered—detected via exception type, status code, or error message—it will wait for the duration specified by the HTTP Retry-After header (if present), or fall back to the delay parameter, and then retry. Retries continue up to max_attempts times. Non-429 errors are immediately re-raised. If all attempts fail due to 429, the last exception is raised.

Exceptions handled:
  • huggingface_hub.utils.HfHubHTTPError

  • requests.exceptions.HTTPError

  • OSError

429 detection is performed by checking the exception’s response.status_code (if available) or by searching for ‘429’ or ‘Too Many Requests’ in the exception message.

Parameters:
  • func (callable) – The function to call.

  • *args – Positional arguments to pass to func.

  • delay (int, optional) – Default wait time (in seconds) between retries if Retry-After is not provided. Defaults to 10.

  • max_attempts (int, optional) – Maximum number of attempts before giving up. Defaults to 100.

  • **kwargs – Keyword arguments to pass to func.

Returns:

The return value of func(*args, **kwargs) if successful.

Raises:

Exception – The original exception if a non-429 error occurs, or if all attempts fail.

Example

>>> from transformers import AutoModel
>>> model = with_hf_retry_ratelimit(
...     AutoModel.from_pretrained,
...     "facebook/ijepa_vith14_1k",
...     delay=10,
...     max_attempts=5,
... )

stable_pretraining.utils.gdrive_utils module#

Google Drive Background Uploader.

Upload files to Google Drive without blocking your application. Features: auto-versioning, background processing, comprehensive logging.

Example

>>> uploader = GDriveUploader("MyProject", "credentials.json")
>>> uploader.upload_file("large_file.mp4")  # Returns immediately!
>>> uploader.wait_for_uploads()  # Optional: wait for completion
class stable_pretraining.utils.gdrive_utils.GDriveUploader(folder_name: str, credentials_path: str, parent_folder_id: str | None = None, callback: Callable[[str, str | None, bool], None] | None = None)[source]#

Bases: object

Background Google Drive uploader with automatic folder versioning.

All uploads happen in a background thread - your code never blocks! Automatically creates versioned folders (MyFolder_v2, MyFolder_v3, etc).

Parameters:
  • folder_name – Name of the Drive folder to create

  • credentials_path – Path to service account JSON credentials

  • parent_folder_id – Optional parent folder ID (None = root)

  • callback – Optional function(file_path, file_id, success) called on completion

Example

>>> def notify(path, fid, ok):
...     print(f"{'✓' if ok else '✗'} {path}")
>>>
>>> uploader = GDriveUploader("Backups", "creds.json", callback=notify)
>>> uploader.upload_file("data.csv")  # Non-blocking!
>>> uploader.upload_file("logs.txt")
>>> uploader.wait_for_uploads()  # Wait for all to finish
get_folder_id() str[source]#

Get the Drive folder ID.

get_folder_url() str[source]#

Get the Drive folder URL.

get_queue_size() int[source]#

Get number of pending uploads.

upload_file(file_path: str, custom_name: str | None = None, subfolder_id: str | None = None) None[source]#

Queue a file for background upload. Returns immediately!

Parameters:
  • file_path – Path to file to upload

  • custom_name – Optional custom name in Drive

  • subfolder_id – Optional subfolder ID (defaults to main folder)

Example

>>> uploader.upload_file("video.mp4")
>>> uploader.upload_file("report.pdf", custom_name="Q4_report.pdf")
wait_for_uploads(timeout: float | None = None) bool[source]#

Wait for all queued uploads to complete.

Parameters:

timeout – Max seconds to wait (None = wait forever)

Returns:

True if completed, False if timeout/error

Example

>>> uploader.upload_file("file1.txt")
>>> uploader.upload_file("file2.txt")
>>> uploader.wait_for_uploads(timeout=300)  # Wait max 5 min

stable_pretraining.utils.inspection_utils module#

Function inspection and general helper utilities.

stable_pretraining.utils.inspection_utils.broadcast_param_to_list(param: Any, target_length: int, param_name: str) List[Any][source]#

Broadcast a parameter value to create a list of specified length.

This function handles the common pattern of accepting either:

  • None: creates a list of None values

  • A single value: broadcasts to all positions

  • A single-element list/tuple: broadcasts the element to all positions

  • A list/tuple of correct length: returns as-is

Parameters:
  • param – The parameter to broadcast (can be None, single value, or list/tuple)

  • target_length – The desired length of the output list

  • param_name – Name of the parameter for error messages

Returns:

List of values with length matching target_length

Raises:

ValueError – If param is a list/tuple with length > 1 that doesn’t match target_length

Examples

>>> broadcast_param_to_list(None, 3, "dims")
[None, None, None]
>>> broadcast_param_to_list(5, 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([5], 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([1, 2, 3], 3, "dims")
[1, 2, 3]
stable_pretraining.utils.inspection_utils.dict_values(**kwargs)[source]#

Convert keyword arguments to a list of values.

Returns:

List of values from the provided keyword arguments

stable_pretraining.utils.inspection_utils.get_required_fn_parameters(fn)[source]#

Get the list of required parameters for a function.

Parameters:

fn – The function to inspect

Returns:

List of parameter names that don’t have default values
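A quick sketch:

>>> def loss_fn(pred, target, weight=1.0):
...     return ((pred - target) ** 2).mean() * weight
>>> get_required_fn_parameters(loss_fn)
['pred', 'target']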

stable_pretraining.utils.lightning_patch module#

Monkey-patch for PyTorch Lightning to support manual optimization with Trainer parameters.

This patch modifies Lightning’s validation to transfer gradient_clip_val and accumulate_grad_batches to alternative attributes instead of raising errors.

stable_pretraining.utils.lightning_patch.apply_manual_optimization_patch()[source]#

Apply the monkey-patch to Lightning’s manual optimization validation.

This patch modifies the __verify_manual_optimization_support function to:

  1. Transfer gradient_clip_val to gradient_clip_val_

  2. Transfer accumulate_grad_batches to accumulate_grad_batches_

  3. Clear the original values to avoid Lightning's error

This allows users to use standard Trainer parameters even with manual optimization.
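A sketch of applying the patch before building a Trainer (the Trainer arguments are illustrative):

>>> import pytorch_lightning as pl  # or lightning.pytorch, depending on your install
>>> apply_manual_optimization_patch()
>>> trainer = pl.Trainer(gradient_clip_val=1.0, accumulate_grad_batches=4)
>>> # a module with self.automatic_optimization = False can now use these settings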

stable_pretraining.utils.lightning_patch.restore_original_validation()[source]#

Restore the original Lightning validation function (for testing/debugging).

stable_pretraining.utils.log_reader module#

Unified log reader for local and wandb logs.

class stable_pretraining.utils.log_reader.LocalLogReader(num_workers: int = 8)[source]#

Bases: LogReader

Reader for local jsonl log files.

read(path: str | Path) List[Dict[str, Any]][source]#

Load values from a single run directory.

Parameters:

path – Path to the run directory

Returns:

List of log entries

read_config(path: str | Path) Dict[str, Any][source]#

Load config from a single run directory.

Parameters:

path – Path to the run directory

Returns:

Configuration dictionary

read_project(folder: str | Path) Tuple[DataFrame, List[List[Dict[str, Any]]]][source]#

Load configs and values from all runs in a folder.

Parameters:

folder – Path to the project folder

Returns:

Tuple of (configs DataFrame, list of values for each run)

class stable_pretraining.utils.log_reader.LogReader[source]#

Bases: ABC

Abstract base class for log readers.

abstractmethod read(*args, **kwargs)[source]#

Read logs from source.

class stable_pretraining.utils.log_reader.TableFormatter[source]#

Bases: object

Format experiment results as tables for analysis.

static create_table(dfs: Dict[str, DataFrame], configs: Dict[str, Any], value: str, row: str, column: str, agg: Callable, filters: Dict[str, Any] | None = None) DataFrame[source]#

Format a pandas DataFrame as a table given the user args.

Parameters:
  • dfs – Dictionary of DataFrames (one per run)

  • configs – Dictionary of configs (one per run)

  • value – Name of the column in dfs to use as values

  • row – Name of the column in configs to use as row

  • column – Name of the column in configs to use as column

  • agg – Aggregator function if many values are present

  • filters – Optional filters to apply to the data

Returns:

Formatted table as DataFrame

static tabulate_runs(configs: DataFrame, runs: List[Any], value: str, ignore: List[str] | None = None) DataFrame[source]#

Create a pivot table from configs and runs for a specific value.

Parameters:
  • configs – DataFrame of run configurations

  • runs – List of run data

  • value – Value to extract from runs

  • ignore – Columns to ignore

Returns:

Pivot table as DataFrame

class stable_pretraining.utils.log_reader.WandbLogReader(num_workers: int = 10)[source]#

Bases: LogReader

Reader for Weights & Biases logs.

read(entity: str, project: str, run_id: str, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, _tqdm_disable: bool = False) Tuple[DataFrame, Dict[str, Any]][source]#

Download data for a single wandb run.

Parameters:
  • entity – Wandb entity name

  • project – Wandb project name

  • run_id – Run ID

  • min_step – Minimum step to download

  • max_step – Maximum step to download (-1 for all)

  • keys – Specific keys to download

  • _tqdm_disable – Whether to disable tqdm progress bar

Returns:

Tuple of (data DataFrame, config dict)

read_project(entity: str, project: str, filters: Dict[str, Any] | None = None, order: str = '+created_at', per_page: int = 50, include_sweeps: bool = True, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, return_summary: bool = False) Tuple[Dict[str, DataFrame], Dict[str, Dict]] | DataFrame[source]#

Download configs and data from a wandb project.

Parameters:
  • entity – Wandb entity name

  • project – Wandb project name

  • filters – Optional filters for runs

  • order – Sort order for runs

  • per_page – Number of runs per page

  • include_sweeps – Whether to include sweep runs

  • min_step – Minimum step to download

  • max_step – Maximum step to download

  • keys – Specific keys to download

  • return_summary – If True, return only summary DataFrame

Returns:

If return_summary is False: tuple of (dfs dict, configs dict). If return_summary is True: a summary DataFrame.

stable_pretraining.utils.log_reader.alphanum_key(key: str) List[int | str][source]#

Convert a string to a list of mixed numbers and strings for natural sorting.

stable_pretraining.utils.log_reader.create_results_table(dfs: Dict[str, DataFrame], configs: Dict[str, Any], value: str, row: str, column: str, agg: Callable = <function mean>, filters: Dict[str, Any] | None = None) DataFrame[source]#

Convenience function to create a results table.

Parameters:
  • dfs – Dictionary of DataFrames (one per run)

  • configs – Dictionary of configs (one per run)

  • value – Name of the column in dfs to use as values

  • row – Name of the column in configs to use as row

  • column – Name of the column in configs to use as column

  • agg – Aggregator function (default: mean)

  • filters – Optional filters to apply

Returns:

Formatted table as DataFrame
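A sketch of summarizing a sweep (the entity, project, metric, and config keys below are hypothetical):

>>> dfs, configs = read_wandb_project("my-entity", "my-project")
>>> table = create_results_table(
...     dfs,
...     configs,
...     value="eval/accuracy",
...     row="optimizer.lr",
...     column="model.name",
... )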

stable_pretraining.utils.log_reader.flatten_config(config: Dict[str, Any]) Dict[str, Any][source]#

Flatten nested config dictionaries into a single level.

Parameters:

config – Nested configuration dictionary

Returns:

Flattened configuration dictionary
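A sketch (assuming nested keys are joined with dots, matching the config.* column names used elsewhere in this package):

>>> flatten_config({"optimizer": {"lr": 0.1, "name": "sgd"}, "seed": 0})
{'optimizer.lr': 0.1, 'optimizer.name': 'sgd', 'seed': 0}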

stable_pretraining.utils.log_reader.natural_sort(values: List[str]) List[str][source]#

Sort a list of strings naturally (handling numbers properly).

stable_pretraining.utils.log_reader.read_local_logs(path: str | Path, num_workers: int = 8) List[Dict[str, Any]][source]#

Convenience function to read local logs.

Parameters:
  • path – Path to the run directory

  • num_workers – Number of parallel workers

Returns:

List of log entries

stable_pretraining.utils.log_reader.read_local_project(folder: str | Path, num_workers: int = 8) Tuple[DataFrame, List[List[Dict[str, Any]]]][source]#

Convenience function to read a local project.

Parameters:
  • folder – Path to the project folder

  • num_workers – Number of parallel workers

Returns:

Tuple of (configs DataFrame, list of values)

stable_pretraining.utils.log_reader.read_wandb_project(entity: str, project: str, filters: Dict[str, Any] | None = None, order: str = '+created_at', per_page: int = 50, include_sweeps: bool = True, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, num_workers: int = 10, return_summary: bool = False) Tuple[Dict[str, DataFrame], Dict[str, Dict]] | DataFrame[source]#

Convenience function to read a wandb project.

Parameters:
  • entity – Wandb entity name

  • project – Wandb project name

  • filters – Optional filters for runs

  • order – Sort order for runs

  • per_page – Number of runs per page

  • include_sweeps – Whether to include sweep runs

  • min_step – Minimum step to download

  • max_step – Maximum step to download

  • keys – Specific keys to download

  • num_workers – Number of parallel workers

  • return_summary – If True, return only summary DataFrame

Returns:

If return_summary is False: tuple of (dfs dict, configs dict). If return_summary is True: a summary DataFrame.

stable_pretraining.utils.log_reader.read_wandb_run(entity: str, project: str, run_id: str, min_step: int = 0, max_step: int = -1, keys: List[str] | None = None, num_workers: int = 10) Tuple[DataFrame, Dict[str, Any]][source]#

Convenience function to read a wandb run.

Parameters:
  • entity – Wandb entity name

  • project – Wandb project name

  • run_id – Run ID

  • min_step – Minimum step to download

  • max_step – Maximum step to download

  • keys – Specific keys to download

  • num_workers – Number of parallel workers

Returns:

Tuple of (data DataFrame, config dict)

stable_pretraining.utils.nn_modules module#

Neural network modules and utilities.

class stable_pretraining.utils.nn_modules.BatchNorm1dNoBias(*args, **kwargs)[source]#

Bases: BatchNorm1d

BatchNorm1d with learnable scale but no learnable bias (center=False).

This is used in contrastive learning methods like SimCLR where the final projection layer uses batch normalization with scale (gamma) but without bias (beta). This follows the original SimCLR implementation where the bias term is removed from the final BatchNorm layer.

The bias is frozen at 0 and set to non-trainable, while the weight (scale) parameter remains learnable.

Example

```python
# SimCLR-style projector
projector = nn.Sequential(
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
    spt.utils.nn_modules.BatchNorm1dNoBias(128),  # Final layer: no bias
)
```

Note

This is equivalent to TensorFlow’s BatchNorm with center=False, scale=True.

class stable_pretraining.utils.nn_modules.EMA(alpha: float)[source]#

Bases: Module

Exponential Moving Average module.

Maintains an exponential moving average of input tensors.

Parameters:

alpha – Smoothing factor between 0 and 1; 0 = no update (always return the first value), 1 = no smoothing (always return the current value)

forward(item)[source]#

Update EMA and return smoothed value.

Parameters:

item – New tensor to incorporate into the average

Returns:

Exponentially smoothed tensor
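A brief sketch:

>>> import torch
>>> ema = EMA(alpha=0.1)
>>> smoothed = ema(torch.tensor(1.0))  # first call initializes the running average
>>> smoothed = ema(torch.tensor(0.0))  # later calls blend new values with the average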

class stable_pretraining.utils.nn_modules.ImageToVideoEncoder(encoder: Module)[source]#

Bases: Module

Wrapper to apply an image encoder to video data by processing each frame independently.

This module takes video data with shape (batch, time, channel, height, width) and applies an image encoder to each frame, returning the encoded features.

Parameters:

encoder (torch.nn.Module) – The image encoder module to apply to each frame.

forward(video)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
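A shape-oriented sketch of the wrapper (the backbone choice is illustrative):

>>> import torch
>>> import torchvision
>>> backbone = torchvision.models.resnet18(weights=None)
>>> encoder = ImageToVideoEncoder(backbone)
>>> video = torch.randn(2, 8, 3, 224, 224)  # (batch, time, channel, height, width)
>>> features = encoder(video)               # backbone applied to each frame independently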

class stable_pretraining.utils.nn_modules.L2Norm(*args, **kwargs)[source]#

Bases: Module

L2 normalization layer that normalizes input to unit length.

Normalizes the input tensor along the last dimension to have unit L2 norm. Commonly used in DINO before the prototypes layer.

Example

```python
projector = nn.Sequential(
    nn.Linear(512, 2048),
    nn.GELU(),
    nn.Linear(2048, 256),
    spt.utils.nn_modules.L2Norm(),  # Normalize to unit length
    nn.Linear(256, 4096, bias=False),  # Prototypes
)
```

forward(x: Tensor) Tensor[source]#

Normalize input to unit L2 norm.

Parameters:

x – Input tensor […, D]

Returns:

L2-normalized tensor […, D] where each D-dimensional vector has unit length

class stable_pretraining.utils.nn_modules.Normalize(*args, **kwargs)[source]#

Bases: Module

Normalize tensor and scale by square root of number of elements.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class stable_pretraining.utils.nn_modules.OrderedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#

Bases: Module

A queue that maintains insertion order of elements.

Similar to UnsortedQueue but tracks the order in which items were inserted, allowing retrieval in the original insertion order even after wraparound.

Parameters:
  • max_length – Maximum number of elements to store in the queue

  • shape – Shape of each element (excluding batch dimension). Can be int or tuple

  • dtype – Data type of the tensors to store

append(item)[source]#

Append item(s) to the queue with order tracking.

Parameters:

item – Tensor to append. First dimension is batch size.

Returns:

Current contents of the queue in insertion order

get()[source]#

Get current contents sorted by insertion order.

Returns:

Tensor containing items sorted by their original insertion order

get_unsorted()[source]#

Get current contents without sorting (like UnsortedQueue).

Returns:

Tensor containing items in buffer order

load_state_dict(state_dict, strict=True, assign=False)[source]#

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter for which the value from the module is preserved. Default: False

Returns:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

class stable_pretraining.utils.nn_modules.UnsortedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#

Bases: Module

A queue data structure that stores tensors with a maximum length.

This module implements a circular buffer that stores tensors up to a maximum length. When the queue is full, new items overwrite the oldest ones.

Parameters:
  • max_length – Maximum number of elements to store in the queue

  • shape – Shape of each element (excluding batch dimension). Can be int or tuple

  • dtype – Data type of the tensors to store

append(item)[source]#

Append item(s) to the queue.

Parameters:

item – Tensor to append. First dimension is batch size.

Returns:

Current contents of the queue

get()[source]#

Get current contents of the queue.

Returns:

Tensor containing all items currently in the queue
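A sketch of using the queue as a feature memory bank (feature_batches is a hypothetical iterable of (B, 128) tensors):

>>> import torch
>>> queue = UnsortedQueue(max_length=1024, shape=128, dtype=torch.float32)
>>> for features in feature_batches:
...     bank = queue.append(features.detach())
>>> bank = queue.get()  # up to max_length stored feature vectors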

stable_pretraining.utils.read_csv_logger module#

class stable_pretraining.utils.read_csv_logger.CSVLogAutoSummarizer(agg: callable | None = None, monitor_keys: List[str] | None = None, include_globs: List[str] | None = None, exclude_globs: List[str] | None = None, max_workers: int | None = 10)[source]#

Bases: object

Automatically discovers and summarizes PyTorch Lightning CSVLogger metrics from Hydra multirun sweeps.

Handles arbitrary directory layouts, sparse metrics, multiple versions (preemption/resume), and aggregates config/hparams metadata into a single DataFrame.

Features:
  • Recursively finds metrics CSVs using common patterns.

  • Infers run root by searching for .hydra/config.yaml (falls back to metrics parent).

  • Handles multiple metrics files per run with configurable grouping strategies: latest_mtime, latest_epoch, latest_step, or merge.

  • Robust CSV parsing (delimiter sniffing, repeated-header cleanup, type coercion).

  • Sparse-metrics aware: last values are last non-NaN; best values ignore NaNs.

  • Loads and flattens Hydra config, overrides (list or dict), and hparams metadata.

Parameters:
  • base_dir – Root directory to search for metrics files.

  • monitor_keys – Metrics to summarize (if None, auto-infer).

  • include_globs – Only include files matching these globs (relative to base_dir).

  • exclude_globs – Exclude files matching these globs (e.g., '/checkpoints/').

  • forward_fill_last – If True, forward-fill the frame (after sorting) before computing last.* summaries.

METRICS_PATTERNS = ['**/metrics.csv', '**/csv/metrics.csv', '**/*metrics*.csv']#
collect(base_dir) DataFrame[source]#

Discover, summarize, and aggregate all runs into a DataFrame.

Parameters:

base_dir – Root directory to search for metrics files.

Returns:

One row per selected metrics source (per file or per run root), with flattened columns such as:

  • last.val_accuracy

  • best.val_loss, best.val_loss.step, best.val_loss.epoch

  • config.optimizer.lr, override.0 (or override as joined string), hparams.seed

Also includes 'metrics_path' and 'run_root'.

Return type:

pd.DataFrame

stable_pretraining.utils.read_csv_logger.load_best_compressed(filename: str) DataFrame[source]#

Load a DataFrame from a compressed file with automatic format detection.

Automatically detects the file format from the extension and tries multiple read methods if needed. This is a smart loader that works with any compressed DataFrame file, not just those created by save_best_compressed.

Parameters:

filename – Path to the compressed file. Standard extensions are recognized: .parquet → Parquet format; .feather → Feather format; .pkl, .pickle → Pickle format; .csv, .csv.gz, .csv.bz2, etc. → CSV format

Returns:

Loaded DataFrame with original dtypes preserved.

Raises:

Supported Formats:
  • Parquet (compression auto-detected)

  • Feather (compression auto-detected)

  • CSV (compression auto-detected from extension)

  • Pickle (compression auto-detected)

Notes

  • Extension-based detection is tried first for speed

  • If extension is ambiguous/missing, tries all formats sequentially

  • Compression is handled automatically by pandas

  • Works with any properly formatted DataFrame file, not just from this module

Examples

>>> # Load with clear extension
>>> df = load_best_compressed("data.parquet")
>>> # Load compressed CSV
>>> df = load_best_compressed("results.csv.gz")
>>> # Load with full path
>>> df = load_best_compressed("/path/to/experiment.feather")
>>> # Even works with ambiguous extensions (tries all methods)
>>> df = load_best_compressed("mystery_file.dat")

See also

save_best_compressed: Saves DataFrame with optimal compression

stable_pretraining.utils.read_csv_logger.save_best_compressed(df: DataFrame, base_path: str = 'data') str[source]#

Optimize DataFrame and save using the most space-efficient format/compression combination.

This function:

  1. Optimizes DataFrame dtypes to reduce memory usage

  2. Tries multiple format/compression combinations in parallel

  3. Keeps only the smallest resulting file with a clean filename

  4. Automatically cleans up all trial files

The following format/compression combinations are tested:

  • Parquet: brotli, zstd, gzip, snappy

  • Feather: zstd, lz4, uncompressed

  • CSV: gzip, bz2, xz, zstd, zip

  • Pickle: infer

Parameters:
  • df – DataFrame to save. Will be optimized before saving.

  • base_path – Base path/name for output file. Any extensions will be automatically removed and replaced with the optimal format. Default: ‘data’

Returns:

Filename of the best-compressed file (smallest size) with clean extension.

Raises:

RuntimeError – If all save attempts fail (no files were created).

Notes

  • Uses parallel processing (ThreadPoolExecutor) for I/O-bound operations

  • All trial files except the smallest are automatically deleted

  • Original DataFrame is not modified (optimization works on a copy)

  • Final file uses standard extensions (e.g., .parquet, .csv.gz, .feather)

  • Any extensions in base_path are automatically stripped

  • If multiple formats produce the same size, the first one sorted is kept

Performance:
  • Time: Depends on number of trials and DataFrame size

  • Memory: Peak usage ~2x DataFrame size (original + optimized copy)

  • Disk: Temporarily creates all trial files before cleanup

Examples

>>> df = pd.DataFrame(
...     {
...         "timestamp": pd.date_range("2024-01-01", periods=1000),
...         "value": np.random.randn(1000),
...         "category": np.random.choice(["A", "B", "C"], 1000),
...     }
... )
>>> # Extension automatically added
>>> best_file = save_best_compressed(df, "results/experiment_01")
>>> print(best_file)
'results/experiment_01.parquet'
>>> # Extensions in input are stripped
>>> best_file = save_best_compressed(df, "data.csv")
>>> print(best_file)
'data.parquet'  # .csv was stripped, optimal format chosen
>>> # Works with multiple extensions
>>> best_file = save_best_compressed(df, "output.csv.gz")
>>> print(best_file)
'output.feather'  # All extensions stripped

See also

_optimize_dataframe – For details on DataFrame optimization

_get_trials – For the complete list of format/compression combinations

load_best_compressed – Smart loader that auto-detects format

stable_pretraining.utils.timm_to_hf_hub module#

stable_pretraining.utils.timm_to_hf_hub.push_timm_to_hf(model_name: str, model: Module, repo_id: str, hf_token: str | None = None, private: bool = False, validate: bool = True, batch_size: int = 2, atol: float = 0.0001, rtol: float = 0.0001, device: str | None = None, strict: bool = False) str[source]#

stable_pretraining.utils.visualization module#

Visualization utilities for SSL experiments.

stable_pretraining.utils.visualization.escape_labels(idx_or_cols)[source]#

Recursively escape all labels in a pandas Index or MultiIndex.

stable_pretraining.utils.visualization.format_df_to_latex(df, caption=None, label=None, bold='row', na_rep='–', sort_index=False, sort_columns=False, column_format=None, position='htbp', escape_headers=True, show_percent_symbol=False, unit_annotation='caption')[source]#

Format a MultiIndex DataFrame for LaTeX export with percent formatting (no % symbol).

Escapes LaTeX special characters in all headers if escape_headers=True.

stable_pretraining.utils.visualization.imshow_with_grid(ax, G: ndarray | Tensor, linewidth: float | None = 0.4, color: str | tuple | None = 'black', bars=[], **kwargs)[source]#

Display a matrix with grid lines overlaid.

Parameters:
  • ax – Matplotlib axes to plot on

  • G – Matrix to display

  • linewidth – Width of grid lines

  • color – Color of grid lines

  • bars – List of bar specifications for highlighting regions

  • **kwargs – Additional arguments for imshow

Returns:

The image object from imshow

stable_pretraining.utils.visualization.latex_escape(s)[source]#

Escape LaTeX special characters in a string.

stable_pretraining.utils.visualization.visualize_images_graph(x, G, zoom_on=8)[source]#

Visualize images and their similarity graph with zoom detail.

Creates a visualization showing:

  • A grid of sample images

  • The full similarity matrix

  • A zoomed-in view of the top-left portion of the matrix

  • Connection lines between the views

Parameters:
  • x – List or tensor of images

  • G – Similarity/adjacency matrix

  • zoom_on – Number of rows/columns to show in zoomed view

Module contents#

Stable-pretraining utilities package.

This package provides various utilities for self-supervised learning experiments including distributed training helpers, custom autograd functions, neural network modules, stable linear algebra operations, data generation, visualization, and configuration management.

class stable_pretraining.utils.BatchNorm1dNoBias(*args, **kwargs)[source]#

Bases: BatchNorm1d

BatchNorm1d with learnable scale but no learnable bias (center=False).

This is used in contrastive learning methods like SimCLR where the final projection layer uses batch normalization with scale (gamma) but without bias (beta). This follows the original SimCLR implementation where the bias term is removed from the final BatchNorm layer.

The bias is frozen at 0 and set to non-trainable, while the weight (scale) parameter remains learnable.

Example

```python
# SimCLR-style projector
projector = nn.Sequential(
    nn.Linear(2048, 2048, bias=False),
    nn.BatchNorm1d(2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128, bias=False),
    spt.utils.nn_modules.BatchNorm1dNoBias(128),  # Final layer: no bias
)
```

Note

This is equivalent to TensorFlow’s BatchNorm with center=False, scale=True.

class stable_pretraining.utils.CSVLogAutoSummarizer(agg: callable | None = None, monitor_keys: List[str] | None = None, include_globs: List[str] | None = None, exclude_globs: List[str] | None = None, max_workers: int | None = 10)[source]#

Bases: object

Automatically discovers and summarizes PyTorch Lightning CSVLogger metrics from Hydra multirun sweeps.

Handles arbitrary directory layouts, sparse metrics, multiple versions (preemption/resume), and aggregates config/hparams metadata into a single DataFrame.

Features:
  • Recursively finds metrics CSVs using common patterns.

  • Infers run root by searching for .hydra/config.yaml (falls back to metrics parent).

  • Handles multiple metrics files per run with configurable grouping strategies: latest_mtime, latest_epoch, latest_step, or merge.

  • Robust CSV parsing (delimiter sniffing, repeated-header cleanup, type coercion).

  • Sparse-metrics aware: last values are last non-NaN; best values ignore NaNs.

  • Loads and flattens Hydra config, overrides (list or dict), and hparams metadata.

Parameters:
  • base_dir – Root directory to search for metrics files.

  • monitor_keys – Metrics to summarize (if None, auto-infer).

  • include_globs – Only include files matching these globs (relative to base_dir).

  • exclude_globs – Exclude files matching these globs (e.g., '/checkpoints/').

  • forward_fill_last – If True, forward-fill the frame (after sorting) before computing last.* summaries.

METRICS_PATTERNS = ['**/metrics.csv', '**/csv/metrics.csv', '**/*metrics*.csv']#
collect(base_dir) DataFrame[source]#

Discover, summarize, and aggregate all runs into a DataFrame.

Parameters:

base_dir – Root directory to search for metrics files.

Returns:

One row per selected metrics source (per file or per run root), with flattened columns such as:

  • last.val_accuracy

  • best.val_loss, best.val_loss.step, best.val_loss.epoch

  • config.optimizer.lr, override.0 (or override as joined string), hparams.seed

Also includes 'metrics_path' and 'run_root'.

Return type:

pd.DataFrame

class stable_pretraining.utils.EMA(alpha: float)[source]#

Bases: Module

Exponential Moving Average module.

Maintains an exponential moving average of input tensors.

Parameters:

alpha – Smoothing factor between 0 and 1; 0 = no update (always return the first value), 1 = no smoothing (always return the current value)

forward(item)[source]#

Update EMA and return smoothed value.

Parameters:

item – New tensor to incorporate into the average

Returns:

Exponentially smoothed tensor

class stable_pretraining.utils.FullGatherLayer(*args, **kwargs)[source]#

Bases: Function

Gather tensors from all processes and support backward propagation.

Supports backward propagation for the gradients across processes.

static backward(ctx, grad)[source]#

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to forward(). Each argument is the gradient w.r.t the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, x)[source]#

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.

class stable_pretraining.utils.GDriveUploader(folder_name: str, credentials_path: str, parent_folder_id: str | None = None, callback: Callable[[str, str | None, bool], None] | None = None)[source]#

Bases: object

Background Google Drive uploader with automatic folder versioning.

All uploads happen in a background thread - your code never blocks! Automatically creates versioned folders (MyFolder_v2, MyFolder_v3, etc).

Parameters:
  • folder_name – Name of the Drive folder to create

  • credentials_path – Path to service account JSON credentials

  • parent_folder_id – Optional parent folder ID (None = root)

  • callback – Optional function(file_path, file_id, success) called on completion

Example

>>> def notify(path, fid, ok):
...     print(f"{'✓' if ok else '✗'} {path}")
>>>
>>> uploader = GDriveUploader("Backups", "creds.json", callback=notify)
>>> uploader.upload_file("data.csv")  # Non-blocking!
>>> uploader.upload_file("logs.txt")
>>> uploader.wait_for_uploads()  # Wait for all to finish
get_folder_id() str[source]#

Get the Drive folder ID.

get_folder_url() str[source]#

Get the Drive folder URL.

get_queue_size() int[source]#

Get number of pending uploads.

upload_file(file_path: str, custom_name: str | None = None, subfolder_id: str | None = None) None[source]#

Queue a file for background upload. Returns immediately!

Parameters:
  • file_path – Path to file to upload

  • custom_name – Optional custom name in Drive

  • subfolder_id – Optional subfolder ID (defaults to main folder)

Example

>>> uploader.upload_file("video.mp4")
>>> uploader.upload_file("report.pdf", custom_name="Q4_report.pdf")
wait_for_uploads(timeout: float | None = None) bool[source]#

Wait for all queued uploads to complete.

Parameters:

timeout – Max seconds to wait (None = wait forever)

Returns:

True if completed, False if timeout/error

Example

>>> uploader.upload_file("file1.txt")
>>> uploader.upload_file("file2.txt")
>>> uploader.wait_for_uploads(timeout=300)  # Wait max 5 min
class stable_pretraining.utils.ImageToVideoEncoder(encoder: Module)[source]#

Bases: Module

Wrapper to apply an image encoder to video data by processing each frame independently.

This module takes video data with shape (batch, time, channel, height, width) and applies an image encoder to each frame, returning the encoded features.

Parameters:

encoder (torch.nn.Module) – The image encoder module to apply to each frame.

forward(video)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class stable_pretraining.utils.L2Norm(*args, **kwargs)[source]#

Bases: Module

L2 normalization layer that normalizes input to unit length.

Normalizes the input tensor along the last dimension to have unit L2 norm. Commonly used in DINO before the prototypes layer.

Example

```python
projector = nn.Sequential(
    nn.Linear(512, 2048),
    nn.GELU(),
    nn.Linear(2048, 256),
    spt.utils.nn_modules.L2Norm(),  # Normalize to unit length
    nn.Linear(256, 4096, bias=False),  # Prototypes
)
```

forward(x: Tensor) Tensor[source]#

Normalize input to unit L2 norm.

Parameters:

x – Input tensor […, D]

Returns:

L2-normalized tensor […, D] where each D-dimensional vector has unit length

class stable_pretraining.utils.Normalize(*args, **kwargs)[source]#

Bases: Module

Normalize tensor and scale by square root of number of elements.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class stable_pretraining.utils.OrderedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#

Bases: Module

A queue that maintains insertion order of elements.

Similar to UnsortedQueue but tracks the order in which items were inserted, allowing retrieval in the original insertion order even after wraparound.

Parameters:
  • max_length – Maximum number of elements to store in the queue

  • shape – Shape of each element (excluding batch dimension). Can be int or tuple

  • dtype – Data type of the tensors to store

append(item)[source]#

Append item(s) to the queue with order tracking.

Parameters:

item – Tensor to append. First dimension is batch size.

Returns:

Current contents of the queue in insertion order

get()[source]#

Get current contents sorted by insertion order.

Returns:

Tensor containing items sorted by their original insertion order

get_unsorted()[source]#

Get current contents without sorting (like UnsortedQueue).

Returns:

Tensor containing items in buffer order

load_state_dict(state_dict, strict=True, assign=False)[source]#

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter for which the value from the module is preserved. Default: False

Returns:

  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

class stable_pretraining.utils.UnsortedQueue(max_length: int, shape: int | Iterable[int] = None, dtype=None)[source]#

Bases: Module

A queue data structure that stores tensors with a maximum length.

This module implements a circular buffer that stores tensors up to a maximum length. When the queue is full, new items overwrite the oldest ones.

Parameters:
  • max_length – Maximum number of elements to store in the queue

  • shape – Shape of each element (excluding batch dimension). Can be int or tuple

  • dtype – Data type of the tensors to store

append(item)[source]#

Append item(s) to the queue.

Parameters:

item – Tensor to append. First dimension is batch size.

Returns:

Current contents of the queue

get()[source]#

Get current contents of the queue.

Returns:

Tensor containing all items currently in the queue

stable_pretraining.utils.adapt_resnet_for_lowres(model)[source]#

Adapt a ResNet model for low resolution images.

Modifies the first convolution layer to use 3x3 kernels with stride 1 and removes the max pooling layer.

Parameters:

model – ResNet model to adapt

Returns:

Modified model

stable_pretraining.utils.all_gather(tensor, *args, **kwargs)[source]#

Gather tensors from all processes.

Parameters:
  • tensor – The tensor to gather

  • *args – Additional arguments for all_gather

  • **kwargs – Additional keyword arguments for all_gather

Returns:

Tuple containing the gathered tensors

stable_pretraining.utils.all_reduce(tensor, *args, **kwargs)[source]#

Reduce tensors across all processes.

Parameters:
  • tensor – The tensor to reduce

  • *args – Additional arguments for all_reduce

  • **kwargs – Additional keyword arguments for all_reduce

Returns:

The reduced tensor

stable_pretraining.utils.broadcast_param_to_list(param: Any, target_length: int, param_name: str) List[Any][source]#

Broadcast a parameter value to create a list of specified length.

This function handles the common pattern of accepting either:

  • None: creates a list of None values

  • A single value: broadcasts to all positions

  • A single-element list/tuple: broadcasts the element to all positions

  • A list/tuple of correct length: returns as-is

Parameters:
  • param – The parameter to broadcast (can be None, single value, or list/tuple)

  • target_length – The desired length of the output list

  • param_name – Name of the parameter for error messages

Returns:

List of values with length matching target_length

Raises:

ValueError – If param is a list/tuple with length > 1 that doesn’t match target_length

Examples

>>> broadcast_param_to_list(None, 3, "dims")
[None, None, None]
>>> broadcast_param_to_list(5, 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([5], 3, "dims")
[5, 5, 5]
>>> broadcast_param_to_list([1, 2, 3], 3, "dims")
[1, 2, 3]
stable_pretraining.utils.compute_pairwise_distances(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean') Tensor[source]#

Compute pairwise distances between two sets of vectors.

Parameters:
  • x – Tensor of shape (n, d) containing n vectors of dimension d

  • y – Tensor of shape (m, d) containing m vectors of dimension d

  • metric – Distance metric to use. Options: "euclidean" (L2 distance), "squared_euclidean" (squared L2 distance), "cosine" (cosine distance, 1 - cosine_similarity), "manhattan" (L1 distance)

Returns:

Distance matrix of shape (n, m) where element (i, j) is the distance between x[i] and y[j]

stable_pretraining.utils.compute_pairwise_distances_chunked(x: Tensor, y: Tensor, metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean', chunk_size: int = 1024) Tensor[source]#

Memory-efficient computation of pairwise distances using chunking.

Parameters:
  • x – Tensor of shape (n, d) containing n vectors of dimension d

  • y – Tensor of shape (m, d) containing m vectors of dimension d

  • metric – Distance metric to use

  • chunk_size – Process y in chunks of this size to save memory

Returns:

Distance matrix of shape (n, m)
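
Example

A sketch checking that the chunked computation matches the direct one up to numerical precision, as expected from the description above.

>>> import torch
>>> from stable_pretraining.utils import (
...     compute_pairwise_distances,
...     compute_pairwise_distances_chunked,
... )
>>> x, y = torch.randn(64, 32), torch.randn(4096, 32)
>>> d_full = compute_pairwise_distances(x, y)
>>> d_chunk = compute_pairwise_distances_chunked(x, y, chunk_size=512)
>>> torch.allclose(d_full, d_chunk, atol=1e-5)
True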

stable_pretraining.utils.detach_tensors(obj: Any) Any[source]#

Recursively traverse an object and return an equivalent structure with all torch tensors detached.

  • Preserves structure, types, and shared references.

  • Handles cycles and arbitrary Python objects (including __dict__ and __slots__).

  • Does not mutate the input; only rebuilds containers if needed.

  • torch.nn.Parameter is replaced with a detached Tensor (not Parameter).

  • Optionally supports attrs classes if ‘attr’ is installed.

Parameters:

obj – The input object (can be arbitrarily nested).

Returns:

A new object with all torch tensors detached, or the original object if no tensors found.

Performance notes:
  • Uses memoization to avoid redundant work and preserve shared/cyclic structure.

  • Avoids unnecessary copies: unchanged subtrees are returned as-is (same id).

  • Shallow-copies objects with __dict__ or __slots__ (does not call __init__).
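
Example

A sketch on a nested dict; the expected outcomes follow the behavior described above (tensors detached, tensor-free subtrees returned as-is).

>>> import torch
>>> from stable_pretraining.utils import detach_tensors
>>> batch = {"feat": torch.randn(2, 4, requires_grad=True) * 2, "meta": {"step": 3}}
>>> out = detach_tensors(batch)
>>> out["feat"].requires_grad
False
>>> out["meta"] is batch["meta"]
True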

stable_pretraining.utils.dict_values(**kwargs)[source]#

Convert keyword arguments to a list of values.

Returns:

List of values from the provided keyword arguments

stable_pretraining.utils.execute_from_config(manager, cfg)[source]#

Execute a function with support for submitit job submission.

If submitit configuration is present, submits the job to a cluster. Otherwise executes locally.

Parameters:
  • manager – Function or callable to execute

  • cfg – Configuration dictionary

Returns:

Result of the executed function

stable_pretraining.utils.find_module(model: Module, module: Module)[source]#

Find all instances of a module type in a model.

Parameters:
  • model – Model to search in

  • module – Module class to search for

Returns:

Tuple of (names, modules) where names are the module paths and modules are the actual module instances
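
Example

A sketch that collects all nn.Linear submodules of a small model; the exact container types and name strings returned may differ from what is shown here.

>>> import torch.nn as nn
>>> from stable_pretraining.utils import find_module
>>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
>>> names, modules = find_module(model, nn.Linear)
>>> list(names)
['0', '2']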

stable_pretraining.utils.format_df_to_latex(df, caption=None, label=None, bold='row', na_rep='–', sort_index=False, sort_columns=False, column_format=None, position='htbp', escape_headers=True, show_percent_symbol=False, unit_annotation='caption')[source]#

Format a MultiIndex DataFrame for LaTeX export with percent-style number formatting (the % symbol is omitted unless show_percent_symbol=True).

Escapes LaTeX special characters in all headers if escape_headers=True.
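
Example

An illustrative sketch; the DataFrame contents are made up, and whether the function returns the LaTeX string or writes it elsewhere is not specified in this docstring.

>>> import pandas as pd
>>> from stable_pretraining.utils import format_df_to_latex
>>> df = pd.DataFrame(
...     {"linear": [71.2, 68.5], "knn": [69.0, 66.1]},
...     index=pd.MultiIndex.from_tuples([("simclr", "r50"), ("byol", "r50")]),
... )
>>> table = format_df_to_latex(df, caption="Top-1 accuracy (%)", label="tab:probe")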

stable_pretraining.utils.generate_dae_samples(x, n, eps, num_workers=10)[source]#

Generate samples for Denoising Autoencoder (DAE) training.

Parameters:
  • x – List of input images

  • n – Number of noisy versions per image

  • eps – Noise level (variance)

  • num_workers – Number of parallel workers

Returns:

Tuple of (noisy_images, similarity_matrix)

stable_pretraining.utils.generate_dm_samples(x, n, betas, i, num_workers=10)[source]#

Generate samples for Diffusion Model training.

Parameters:
  • x – List of input images

  • n – Number of noisy versions per timestep

  • betas – Noise schedule beta values

  • i – Timestep indices to use

  • num_workers – Number of parallel workers

Returns:

Tuple of (noisy_images, similarity_matrix)

stable_pretraining.utils.generate_ssl_samples(x, n, num_workers=10)[source]#

Generate augmented samples for self-supervised learning.

Creates n augmented versions of each image.

Parameters:
  • x – List of input images

  • n – Number of augmented versions per image

  • num_workers – Number of parallel workers

Returns:

Tuple of (augmented_images, similarity_matrix)

stable_pretraining.utils.generate_sup_samples(x, y, n, num_workers=10)[source]#

Generate samples for supervised learning with class structure.

Only includes classes with at least n samples.

Parameters:
  • x – List of input images

  • y – Class labels

  • n – Minimum samples per class

  • num_workers – Number of parallel workers

Returns:

Tuple of (processed_images, class_similarity_matrix)

stable_pretraining.utils.get_data_from_batch_or_outputs(key: Iterable[str] | str, batch: Dict[str, Any], outputs: Dict[str, Any] | None = None, caller_name: str = 'Callback') Any | None[source]#

Get data from either outputs or batch dictionary.

In PyTorch Lightning, the outputs parameter in callbacks contains the return value from training_step/validation_step, while batch contains the original input. Since forward methods may modify batch in-place but Lightning creates a copy for outputs, we need to check both.

Parameters:
  • key – The key(s) to look for in the dictionaries

  • batch – The original batch dictionary

  • outputs – The outputs dictionary from training/validation step

  • caller_name – Name of the calling function/class for logging

Returns:

The data associated with the key, or None if not found
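
Example

A sketch of typical use inside a Lightning callback hook; the “embedding” key and the MyProbe caller name are hypothetical.

>>> from stable_pretraining.utils import get_data_from_batch_or_outputs
>>> def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
...     embedding = get_data_from_batch_or_outputs(
...         "embedding", batch, outputs, caller_name="MyProbe"
...     )
...     if embedding is None:
...         return  # key found in neither outputs nor batch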

stable_pretraining.utils.get_required_fn_parameters(fn)[source]#

Get the list of required parameters for a function.

Parameters:

fn – The function to inspect

Returns:

List of parameter names that don’t have default values
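
Example

A short sketch; per the description above, only parameters without default values are returned.

>>> from stable_pretraining.utils import get_required_fn_parameters
>>> def fn(a, b, c=1):
...     return a + b + c
>>> get_required_fn_parameters(fn)
['a', 'b']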

stable_pretraining.utils.is_dist_avail_and_initialized()[source]#

Check if distributed training is available and initialized.

Returns:

True if distributed is available and initialized, False otherwise

Return type:

bool

stable_pretraining.utils.load_hparams_from_ckpt(path: str) Any[source]#

Loads a checkpoint safely in both distributed and non-distributed settings.

  • If torch.distributed is initialized, only src_rank loads from disk, then broadcasts to all other ranks.

  • If not, loads directly from disk.

Parameters:

path – Path to checkpoint file.

Returns:

The loaded hparams.

stable_pretraining.utils.replace_module(model, replacement_mapping)[source]#

Replace modules in a model based on a mapping function.

Parameters:
  • model – PyTorch model to modify

  • replacement_mapping – Function that takes (name, module) and returns the replacement module

Returns:

Modified model
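
Example

A sketch that swaps every ReLU for GELU via a (name, module) mapping function.

>>> import torch.nn as nn
>>> from stable_pretraining.utils import replace_module
>>> def relu_to_gelu(name, module):
...     return nn.GELU() if isinstance(module, nn.ReLU) else module
>>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
>>> model = replace_module(model, relu_to_gelu)
>>> isinstance(model[1], nn.GELU)
True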

stable_pretraining.utils.rgetattr(obj, attr)[source]#

Recursively get an attribute using dot notation.

Parameters:
  • obj – Object to get attribute from

  • attr – Attribute path (e.g., “module.layer.weight”)

Returns:

The requested attribute value

stable_pretraining.utils.rsetattr(obj, attr, val)[source]#

Recursively set an attribute using dot notation.

Parameters:
  • obj – Object to set attribute on

  • attr – Attribute path (e.g., “module.layer.weight”)

  • val – Value to set
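
Example

A short sketch that reads and writes a nested attribute using dot notation.

>>> import torch
>>> import torch.nn as nn
>>> from stable_pretraining.utils import rgetattr, rsetattr
>>> class Encoder(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.backbone = nn.Linear(4, 4)
>>> enc = Encoder()
>>> rgetattr(enc, "backbone.weight").shape
torch.Size([4, 4])
>>> rsetattr(enc, "backbone.bias", nn.Parameter(torch.zeros(4)))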

stable_pretraining.utils.with_hf_retry_ratelimit(func, *args, delay=10, max_attempts=100, **kwargs)[source]#

Calls the given function with retry logic for HTTP 429 (Too Many Requests) errors.

This function attempts to call func(*args, **kwargs). If a rate-limiting error (HTTP 429) is encountered—detected via exception type, status code, or error message—it will wait for the duration specified by the HTTP Retry-After header (if present), or fall back to the delay parameter, and then retry. Retries continue up to max_attempts times. Non-429 errors are immediately re-raised. If all attempts fail due to 429, the last exception is raised.

Exceptions handled:
  • huggingface_hub.utils.HfHubHTTPError

  • requests.exceptions.HTTPError

  • OSError

429 detection is performed by checking the exception’s response.status_code (if available) or by searching for ‘429’ or ‘Too Many Requests’ in the exception message.

Parameters:
  • func (callable) – The function to call.

  • *args – Positional arguments to pass to func.

  • delay (int, optional) – Default wait time (in seconds) between retries if Retry-After is not provided. Defaults to 10.

  • max_attempts (int, optional) – Maximum number of attempts before giving up. Defaults to 100.

  • **kwargs – Keyword arguments to pass to func.

Returns:

The return value of func(*args, **kwargs) if successful.

Raises:

Exception – The original exception if a non-429 error occurs, or if all attempts fail.

Example

>>> from transformers import AutoModel
>>> model = with_hf_retry_ratelimit(
...     AutoModel.from_pretrained,
...     "facebook/ijepa_vith14_1k",
...     delay=10,
...     max_attempts=5,
... )