stable_pretraining.data package#

Submodules#

stable_pretraining.data.collate module#

class stable_pretraining.data.collate.Collator(G_from=None)[source]#

Bases: object

Custom collate function that optionally builds an affinity (or “graph”) matrix based on a specified field.
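
Example

A minimal usage sketch: pass a Collator instance as the DataLoader's collate_fn. The "embedding" field name below is hypothetical; G_from must name a field present in each sample.

    from torch.utils.data import DataLoader
    from stable_pretraining.data import Collator

    # Build an affinity matrix from the hypothetical "embedding" field
    collate = Collator(G_from="embedding")
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate)  # dataset: any map-style dataset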

stable_pretraining.data.dataset_stats module#

Dataset statistics for normalization.

This module contains pre-computed mean and standard deviation values for various common datasets, used for data normalization during preprocessing.

stable_pretraining.data.datasets module#

Dataset classes for real data sources.

This module provides dataset wrappers and utilities for working with real data sources including PyTorch datasets, HuggingFace datasets, and dataset subsets.

class stable_pretraining.data.datasets.Dataset(transform=None)[source]#

Bases: Dataset

Base dataset class with transform support and PyTorch Lightning integration.

process_sample(sample, **kwargs)[source]#
set_pl_trainer(trainer: Trainer)[source]#
class stable_pretraining.data.datasets.FromTorchDataset(dataset, names, transform=None, add_sample_idx=True)[source]#

Bases: Dataset

Wrapper for PyTorch datasets with custom column naming and transforms.

Parameters:
  • dataset – PyTorch dataset to wrap

  • names – List of names for each element returned by the dataset

  • transform – Optional transform to apply to samples

  • add_sample_idx – If True, automatically adds ‘sample_idx’ field to each sample

property column_names#
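
Example

A short sketch of wrapping a torchvision dataset; the column names simply label the elements of the tuple the base dataset returns.

    import torchvision
    from stable_pretraining.data import FromTorchDataset

    base = torchvision.datasets.CIFAR10(root="./data", download=True)
    dataset = FromTorchDataset(base, names=["image", "label"])
    sample = dataset[0]
    # sample is a dict with "image", "label", and (by default) "sample_idx"
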
class stable_pretraining.data.datasets.HFDataset(*args, transform=None, rename_columns=None, remove_columns=None, **kwargs)[source]#

Bases: Dataset

Hugging Face dataset wrapper with transform and column manipulation support.

property column_names#
is_saved_with_save_to_disk(path)[source]#
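
Example

A hedged sketch, assuming positional and keyword arguments are forwarded to datasets.load_dataset (the dataset name and column names below are illustrative):

    from stable_pretraining.data import HFDataset

    dataset = HFDataset(
        "cifar10",
        split="train",
        rename_columns={"img": "image"},
    )
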
class stable_pretraining.data.datasets.Subset(dataset: Dataset, indices: Sequence[int])[source]#

Bases: Dataset

Subset of a dataset at specified indices.

Parameters:
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset

property column_names#
dataset: Dataset#
indices: Sequence[int]#

stable_pretraining.data.download module#

Download utilities for fetching datasets and model weights.

This module provides functions for downloading files from URLs with progress tracking, caching, and concurrent download support.

stable_pretraining.data.download.bulk_download(urls: Iterable[str], dest_folder: str | Path, backend: str = 'filesystem', cache_dir: str = '~/.stable_pretraining/')[source]#

Download multiple files concurrently.

Example

import stable_pretraining

stable_pretraining.data.bulk_download(
    [...],  # URLs to download
    "todelete",
)

Parameters:
  • urls (Iterable[str]) – List of URLs to download

  • dest_folder (Union[str, Path]) – Destination folder for downloads

  • backend (str, optional) – Storage backend type. Defaults to “filesystem”.

  • cache_dir (str, optional) – Cache directory path. Defaults to “~/.stable_pretraining/”.

stable_pretraining.data.download.download(url, dest_folder, backend='filesystem', cache_dir='~/.stable_pretraining/', progress_bar=True, _progress_dict=None, _task_id=None)[source]#

Download a file from a URL with progress tracking.

Parameters:
  • url – URL to download from

  • dest_folder – Destination folder for the download

  • backend – Storage backend type

  • cache_dir – Cache directory path

  • progress_bar – Whether to show progress bar

  • _progress_dict – Internal dictionary for progress tracking

  • _task_id – Internal task ID for bulk downloads

Returns:

Path to the downloaded file or None if download failed
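
Example

A minimal sketch (the URL is a placeholder):

    from stable_pretraining.data import download

    path = download("https://example.com/weights.ckpt", "checkpoints/")
    if path is None:
        print("download failed")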

stable_pretraining.data.masking module#

stable_pretraining.data.masking.multi_block_mask(height: int, width: int, block_scales: list[tuple[float, float]] = [(0.85, 1.0), (0.15, 0.2), (0.15, 0.2), (0.15, 0.2), (0.15, 0.2)], aspect_ratios: list[tuple[float, float]] = [(1.0, 1.0), (0.75, 1.5), (0.75, 1.5), (0.75, 1.5), (0.75, 1.5)], min_keep: int = 1, seed: int = 0) list[Tensor][source]#

Generates a list of distinct, randomly placed, block-shaped binary masks.

This function creates a series of block masks based on provided scale and aspect ratio specifications. For each pair of (scale, aspect_ratio), it first samples a block size (height, width). It then places this block at a random location within the grid of the specified height and width.

The process is repeated for all items in the input lists to produce a list of independent masks.

Example

>>> # xdoctest: +SKIP
>>> # Generate masks for a 14x14 patch grid
>>> masks = multi_block_mask(height=14, width=14)
>>> len(masks)
5
>>> masks[0].nonzero().size(0)
169
>>> masks[1].nonzero().size(0)
30
Parameters:
  • height (int) – The height of the grid to generate masks for (in patches).

  • width (int) – The width of the grid to generate masks for (in patches).

  • block_scales (list[tuple[float, float]]) – A list of tuples, each defining the min/max scale of a block’s area relative to the total grid area.

  • aspect_ratios (list[tuple[float, float]]) – A list of tuples, each defining the min/max aspect ratio (height/width) for a corresponding block.

  • min_keep (int) – The minimum number of `1`s required for a valid block mask.

  • seed (int) – A seed for the random number generator to ensure reproducibility.

Returns:

A list of 2D binary masks. Each tensor has a shape of (height, width), where 1s indicate the masked block and 0s are the background.

Return type:

list[torch.Tensor]

stable_pretraining.data.module module#

class stable_pretraining.data.module.DataModule(train: dict | DictConfig | DataLoader | None = None, test: dict | DictConfig | DataLoader | None = None, val: dict | DictConfig | DataLoader | None = None, predict: dict | DictConfig | DataLoader | None = None, **kwargs)[source]#

Bases: LightningDataModule

PyTorch Lightning DataModule for handling train/val/test/predict dataloaders.
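
Example

A minimal sketch constructing the DataModule directly from DataLoaders; per the signature, dict or DictConfig specifications are accepted as well.

    from torch.utils.data import DataLoader
    from stable_pretraining.data import DataModule

    datamodule = DataModule(
        train=DataLoader(train_dataset, batch_size=256, shuffle=True),
        val=DataLoader(val_dataset, batch_size=256),
    )
    # trainer.fit(module, datamodule=datamodule)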

load_state_dict(state_dict)[source]#

Called when loading a checkpoint; implement to reload datamodule state given the datamodule state_dict.

Parameters:

state_dict – the datamodule state returned by state_dict.

predict_dataloader()[source]#

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

set_pl_trainer(trainer: Trainer)[source]#
setup(stage)[source]#

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = load_data()

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
state_dict()[source]#

Called when saving a checkpoint; implement to generate and save datamodule state.

Returns:

A dictionary containing datamodule state.

teardown(stage: str)[source]#

Called at the end of fit (train + validate), validate, test, or predict.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

test_dataloader()[source]#

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader()[source]#

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()[source]#

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). This dataloader is used in conjunction with:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class stable_pretraining.data.module.DictFormat(dataset: Iterable, names: Iterable)[source]#

Bases: Dataset

Format dataset with named columns for dictionary-style access.

Parameters:
  • dataset (Iterable) – Dataset to be wrapped.

  • names (Iterable) – Column names for the dataset.

stable_pretraining.data.sampler module#

class stable_pretraining.data.sampler.RandomBatchSampler(batch_size: int, length_or_dataset: Dataset | int, *args, **kwargs)[source]#

Bases: Sampler[List[int]]

Yields mini-batches of randomly sampled indices.

Parameters:
  • batch_size (int) – Size of mini-batch.

  • length_or_dataset (Dataset or int) – Dataset to sample from, or the total number of samples.

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class stable_pretraining.data.sampler.RepeatedRandomSampler(data_source_or_len: int | Iterable, n_views: int = 1, replacement: bool = False, seed: int = 0, pass_view_idx: bool = False)[source]#

Bases: DistributedSampler

Sampler that repeats each dataset index consecutively for multi-view learning.

IMPORTANT: This sampler repeats each index n_views times in a row, creating sequences like [0,0,0,0, 1,1,1,1, 2,2,2,2, …] for n_views=4. This means:

  • The DataLoader will load the SAME image multiple times consecutively

  • Each repeated index goes through the transform pipeline separately

  • BATCH SIZE: The batch_size in DataLoader refers to total augmented samples. For example, batch_size=128 with n_views=8 means only 16 unique images, each appearing 8 times with different augmentations.

Designed to work with RoundRobinMultiViewTransform, which uses a counter to apply different augmentations to each repeated occurrence of the same image.

Example behavior with n_views=3:

Dataset indices: [0, 1, 2, 3, 4]
Sampler output: [0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4]

Parameters:
  • data_source (Dataset) – dataset to sample from

  • n_views (int) – number of times to repeat each index consecutively, default=1

  • replacement (bool) – samples are drawn on-demand with replacement if True, default=False

  • seed (int) – random seed for shuffling

  • pass_view_idx (bool) – whether to pass the view index to the dataset getitem

Note: For an alternative approach that loads each image once, consider using MultiViewTransform with a standard sampler.
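
Example

A usage sketch pairing the sampler with a DataLoader; remember that batch_size counts augmented samples, not unique images.

    from torch.utils.data import DataLoader
    from stable_pretraining.data import RepeatedRandomSampler

    sampler = RepeatedRandomSampler(len(dataset), n_views=2)
    # 128 augmented samples per batch = 64 unique images x 2 views
    loader = DataLoader(dataset, batch_size=128, sampler=sampler)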

class stable_pretraining.data.sampler.SupervisedBatchSampler(batch_size: int, n_views: int, targets_or_dataset: Dataset | list, *args, **kwargs)[source]#

Bases: Sampler[List[int]]

Yields mini-batches of indices grouped using the dataset targets.

Parameters:
  • batch_size (int) – Size of mini-batch.

  • n_views (int) – Number of views per sample.

  • targets_or_dataset (Dataset or list) – Dataset providing the targets, or the list of targets itself.

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

stable_pretraining.data.synthetic_data module#

Synthetic and simulated data generators.

This module provides various synthetic data generators including manifold datasets, noise generators, statistical models, and simulated environments for testing and experimentation purposes.

class stable_pretraining.data.synthetic_data.Categorical(values: list | Tensor, probabilities: list | Tensor)[source]#

Bases: Module

Categorical distribution for sampling discrete values with given probabilities.

sample(*args, **kwargs)[source]#
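
Example

A small sketch; sample() takes *args/**kwargs, so a no-argument call is assumed here.

    from stable_pretraining.data import Categorical

    dist = Categorical(values=[0, 1, 2], probabilities=[0.2, 0.3, 0.5])
    draw = dist.sample()  # assumed to return one of 0, 1, 2
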
class stable_pretraining.data.synthetic_data.ExponentialMixtureNoiseModel(rates, prior, upper_bound=inf)[source]#

Bases: Module

Exponential mixture noise model for data augmentation or sampling.

sample(*args, **kwargs)[source]#
class stable_pretraining.data.synthetic_data.ExponentialNormalNoiseModel(rate, mean, std, prior, upper_bound=inf)[source]#

Bases: Module

Exponential-normal noise model combining exponential and normal distributions.

sample(*args, **kwargs)[source]#
class stable_pretraining.data.synthetic_data.GMM(num_components=5, num_samples=100, dim=2)[source]#

Bases: Dataset

Gaussian Mixture Model dataset for synthetic data generation.

score(samples)[source]#
class stable_pretraining.data.synthetic_data.MinariEpisodeDataset(dataset)[source]#

Bases: Dataset

Dataset for Minari reinforcement learning data with episode-based access.

NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
property column_names#
nested_step(value, idx)[source]#
set_pl_trainer(trainer)[source]#
class stable_pretraining.data.synthetic_data.MinariStepsDataset(dataset, num_steps=2, transform=None)[source]#

Bases: Dataset

Dataset for Minari reinforcement learning data with step-based access.

NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
property column_names#
nested_step(value, idx)[source]#
stable_pretraining.data.synthetic_data.generate_perlin_noise_2d(shape, res, octaves=1, persistence=0.5, lacunarity=2.0)[source]#

Generate 2D Perlin noise.

Parameters:
  • shape – Output shape (height, width)

  • res – Resolution tuple

  • octaves – Number of octaves for fractal noise

  • persistence – Amplitude multiplier for each octave

  • lacunarity – Frequency multiplier for each octave

Returns:

2D tensor of Perlin noise
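
Example

A small sketch; as with most Perlin implementations, shape is assumed to be divisible by res.

    from stable_pretraining.data import generate_perlin_noise_2d

    # 128x128 noise field over a 4x4 base lattice, with 3 fractal octaves
    noise = generate_perlin_noise_2d(shape=(128, 128), res=(4, 4), octaves=3)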

stable_pretraining.data.synthetic_data.perlin_noise_3d(x, y, z)[source]#

Generate 3D Perlin noise at given coordinates.

Parameters:
  • x – X coordinate for noise generation

  • y – Y coordinate for noise generation

  • z – Z coordinate for noise generation

Returns:

Perlin noise value at the given coordinates

stable_pretraining.data.synthetic_data.swiss_roll(N, margin=1, sampler_time=Uniform(low=0.1, high=3.0), sampler_width=Uniform(low=0.0, high=1.0))[source]#

Generate Swiss Roll dataset points.

Parameters:
  • N – Number of points to generate

  • margin – Margin parameter for the roll

  • sampler_time – Distribution for sampling time parameter

  • sampler_width – Distribution for sampling width parameter

Returns:

Tensor of shape (N, 3) containing Swiss Roll points
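
Example

A one-line usage sketch using the documented defaults:

    from stable_pretraining.data import swiss_roll

    points = swiss_roll(1000)  # tensor of shape (1000, 3)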

stable_pretraining.data.transforms module#

class stable_pretraining.data.transforms.AdditiveGaussian(sigma, p=1)[source]#

Bases: Transform

Add Gaussian noise to input data.

BYPASS_VALUE = False#
class stable_pretraining.data.transforms.CenterCrop(size, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, CenterCrop

Crop the center of an image to the given size.

class stable_pretraining.data.transforms.ColorJitter(brightness=None, contrast=None, saturation=None, hue=None, p=1, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, ColorJitter

Randomly change brightness, contrast, saturation, and hue of an image.

static get_params(brightness: list[float] | None, contrast: list[float] | None, saturation: list[float] | None, hue: list[float] | None) tuple[Tensor, float | None, float | None, float | None, float | None][source]#

Get the parameters for the randomized transform to be applied on image.

Parameters:
  • brightness (tuple of float (min, max), optional) – The range from which the brightness_factor is chosen uniformly. Pass None to turn off the transformation.

  • contrast (tuple of float (min, max), optional) – The range from which the contrast_factor is chosen uniformly. Pass None to turn off the transformation.

  • saturation (tuple of float (min, max), optional) – The range from which the saturation_factor is chosen uniformly. Pass None to turn off the transformation.

  • hue (tuple of float (min, max), optional) – The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation.

Returns:

The parameters used to apply the randomized transform along with their random order.

Return type:

tuple

class stable_pretraining.data.transforms.Compose(*args)[source]#

Bases: Transform

Compose multiple transforms together in sequence.
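
Example

A typical pipeline sketch using the dict-based transforms from this module (parameter values are illustrative; pil_image is a placeholder input).

    import stable_pretraining.data.transforms as T

    transform = T.Compose(
        T.RGB(),
        T.RandomResizedCrop(size=224),
        T.RandomHorizontalFlip(p=0.5),
        T.ToImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )
    view = transform({"image": pil_image, "label": 0})  # pil_image: a PIL image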

class stable_pretraining.data.transforms.Conditional(transform, condition_key, apply_on_true=True)[source]#

Bases: Transform

Apply transform conditionally based on a data dictionary key.

class stable_pretraining.data.transforms.ContextTargetsMultiBlockMask(patch_size=16, context_scale=(0.85, 1.0), context_aspect_ratio=(1.0, 1.0), target_scales=((0.15, 0.2), (0.15, 0.2), (0.15, 0.2), (0.15, 0.2)), target_aspect_ratios=((0.75, 1.5), (0.75, 1.5), (0.75, 1.5), (0.75, 1.5)), min_keep=10, source: str = 'image', target_context: str = 'mask_context', target_targets: str = 'masks_target')[source]#

Bases: Transform

Transform that adds multi-block masks to batch, with multiple target blocks and one disjoint context block.

Parameters:
  • patch_size – Size of each square patch, in pixels

  • context_scale – Min/max scale of the context block relative to the grid area

  • context_aspect_ratio – Min/max aspect ratio of the context block

  • target_scales – Min/max scale ranges, one per target block

  • target_aspect_ratios – Min/max aspect ratio ranges, one per target block

  • min_keep – Minimum number of patches that must be in the block

class stable_pretraining.data.transforms.ControlledTransform(transform: callable, seed_offset: int = 0, key: str | None = 'idx')[source]#

Bases: Transform

Applies a transform whose random seed is derived from a per-sample key (default ‘idx’) plus seed_offset, making augmentations reproducible per sample.

class stable_pretraining.data.transforms.GaussianBlur(kernel_size, sigma=(0.1, 2.0), p=1, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, GaussianBlur

Apply Gaussian blur to image with random sigma values.

static get_params(sigma_min: float, sigma_max: float) float[source]#

Choose sigma for random gaussian blurring.

Parameters:
  • sigma_min (float) – Minimum standard deviation that can be chosen for blurring kernel.

  • sigma_max (float) – Maximum standard deviation that can be chosen for blurring kernel.

Returns:

Standard deviation to be passed to calculate kernel for gaussian blurring.

Return type:

float

class stable_pretraining.data.transforms.Lambda(lambd, source: str = 'image', target: str = 'image')[source]#

Bases: Transform

Applies a callable to the value at the source key and stores the result at the target key.

class stable_pretraining.data.transforms.MultiViewTransform(transforms)[source]#

Bases: Transform

Creates multiple views from one sample by applying different transforms.

Takes a single sample and applies different transforms to create multiple views, returning a list of complete sample dicts. Preserves all modifications each transform makes (masks, augmentation params, metadata, etc.).

Implementation Note:

This transform uses a shallow copy (dict.copy()) of the input sample before applying each transform. This is efficient and safe because:

  • The shallow copy shares references to the original tensors/objects

  • Standard transforms create NEW tensors (e.g., through mul(), resize(), crop()) rather than modifying inputs in-place

  • The original sample remains unchanged

Consequences of shallow copy:
  • Memory efficient: Original tensors are not duplicated unnecessarily

  • Safe with torchvision transforms: All torchvision transforms and our custom transforms follow the pattern of creating new tensors

  • Caution: If using custom transforms that modify tensors in-place (using operations like mul_(), add_() with underscore), views may interfere with each other. Always use non-in-place operations in custom transforms.

Parameters:

transforms – Either a list or dict of transforms. A list returns views in the same order; a dict returns views under the same keys.

Returns:

  • If transforms is a list: Returns a list of transformed sample dicts

  • If transforms is a dict: Returns a dict of transformed sample dicts with same keys

Each dict contains NEW tensors, not references to the original.

Return type:

Union[List[dict], Dict[str, dict]]

Example

# List input - returns list of views
transform = MultiViewTransform([
    strong_augmentation,  # Creates first view with strong aug
    weak_augmentation,    # Creates second view with weak aug
])
# Input: {"image": img, "label": 0}
# Output: [{"image": img_strong, "label": 0}, {"image": img_weak, "label": 0}]

# Dict input - returns dict of named views
transform = MultiViewTransform({
    "student": strong_augmentation,
    "teacher": weak_augmentation,
})
# Input: {"image": img, "label": 0}
# Output: {"student": {"image": img_strong, "label": 0},
#          "teacher": {"image": img_weak, "label": 0}}

class stable_pretraining.data.transforms.PILGaussianBlur(sigma=None, p=1, source: str = 'image', target: str = 'image')[source]#

Bases: Transform

PIL-based Gaussian blur transform with random sigma sampling.

class stable_pretraining.data.transforms.PatchMasking(patch_size: int = 16, drop_ratio: float = 0.5, source: str = 'image', target: str = 'image', fill_value: float = None, mask_key: str = 'patch_mask')[source]#

Bases: Transform

Randomly masks square patches in an image, similar to patch masking used in Masked Signal Encoding (MSE) tasks.

This transform operates on a dictionary input, applies patch masking to the image found at the specified source key, and writes the masked image to the target key. It also saves a boolean mask matrix (one entry per patch) to the mask_key in the dictionary, indicating which patches were masked (False) or kept (True). The output image remains in the same format as the input (PIL Image or Tensor), and the masking is performed efficiently for both input types.

Parameters:
  • patch_size (int) – The size (in pixels) of each square patch to be masked.

  • drop_ratio (float) – The exact fraction of patches to randomly mask (set to the mask value).

  • source (str) – The key in the input dictionary from which to read the image.

  • target (str) – The key in the output dictionary to which the masked image will be written.

  • mask_key (str) – The key in the output dictionary to which the boolean patch mask will be written.

  • fill_value (float, optional) – The value to use for masked patches. If None, defaults to 0.0 for float tensors, and 128/255.0 for PIL images (mid-gray). Can be set to any float in [0,1] for normalized images.
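
Example

A short sketch of the dictionary-in, dictionary-out behavior described above.

    import torch
    from stable_pretraining.data.transforms import PatchMasking

    t = PatchMasking(patch_size=16, drop_ratio=0.5)
    out = t({"image": torch.rand(3, 224, 224)})
    masked_image = out["image"]     # same format as the input
    patch_mask = out["patch_mask"]  # boolean mask, one entry per patch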

class stable_pretraining.data.transforms.RGB(source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RGB

Convert image to RGB format.

class stable_pretraining.data.transforms.RandomChannelPermutation(source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomChannelPermutation

Randomly permute the channels of an image.

class stable_pretraining.data.transforms.RandomContiguousTemporalSampler(source, target, num_frames, frame_subsampling: int = 1)[source]#

Bases: Transform

Randomly sample contiguous frames from a video sequence.

class stable_pretraining.data.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant', source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomCrop

Crop a random portion of image and resize it to given size.

static get_params(img: Tensor, output_size: tuple[int, int]) tuple[int, int, int, int][source]#

Get parameters for crop for a random crop.

Parameters:
  • img (PIL Image or Tensor) – Image to be cropped.

  • output_size (tuple) – Expected output size of the crop.

Returns:

params (i, j, h, w) to be passed to crop for random crop.

Return type:

tuple

class stable_pretraining.data.transforms.RandomGrayscale(p=0.1, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomGrayscale

Randomly convert image to grayscale with given probability.

class stable_pretraining.data.transforms.RandomHorizontalFlip(p=0.5, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomHorizontalFlip

Horizontally flip the given image randomly with a given probability.

class stable_pretraining.data.transforms.RandomMask(patch_size=16, mask_ratio=0.75, source: str = 'image', target_visible: str = 'mask_visible', target_masked: str = 'mask_masked', target_ids_restore: str = 'ids_restore', target_len_keep: str = 'len_keep')[source]#

Bases: Transform

Creates a random MAE-style mask for an image.

This transform generates a random permutation of all patch indices for an input image. It then splits these indices into two disjoint sets: ‘visible’ and ‘masked’, according to the specified mask_ratio.

It also provides an ids_restore tensor, which can un-shuffle a sequence of patches back to its original 2D grid order. All outputs are added as new keys to the sample dictionary.

Example

>>> # xdoctest: +SKIP
>>> transform = RandomMask(patch_size=16, mask_ratio=0.75)
>>> sample = {"image": torch.randn(3, 224, 224)}
>>> result = transform(sample)
>>> sorted(result.keys())
['image', 'ids_restore', 'len_keep', 'mask_masked', 'mask_visible']
>>> result["len_keep"]
49
>>> result["mask_visible"].shape
torch.Size([49])
Parameters:
  • patch_size (int) – The height and width of each square patch.

  • mask_ratio (float) – The fraction of patches to be masked (e.g., 0.75).

  • source (str) – The key in the sample dict for the source image tensor.

  • target_visible (str) – The key to use when storing visible patch indices.

  • target_masked (str) – The key to use when storing masked patch indices.

  • target_ids_restore (str) – The key to use for the restoration indices.

  • target_len_keep (str) – The key to use for the count of visible patches.

class stable_pretraining.data.transforms.RandomResizedCrop(size: int | Sequence[int], scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (0.75, 1.3333333333333333), interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, antialias: bool | None = True, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomResizedCrop

Crop a random portion of image and resize it to given size.

static get_params(img: Tensor, scale: list[float], ratio: list[float]) tuple[int, int, int, int][source]#

Get parameters for crop for a random sized crop.

Parameters:
  • img (PIL Image or Tensor) – Input image.

  • scale (list) – range of the crop area, relative to the area of the original image

  • ratio (list) – range of aspect ratios for the crop

Returns:

params (i, j, h, w) to be passed to crop for a random sized crop.

Return type:

tuple

class stable_pretraining.data.transforms.RandomRotation(degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomRotation

Rotate image by random angle within specified degrees range.

static get_params(degrees: list[float]) float[source]#

Get parameters for rotate for a random rotation.

Returns:

angle parameter to be passed to rotate for random rotation.

Return type:

float

class stable_pretraining.data.transforms.RandomSolarize(threshold, p=0.5, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, RandomSolarize

Randomly solarize image by inverting pixel values above threshold.

class stable_pretraining.data.transforms.Resize(size, interpolation=2, max_size=None, antialias=True, source='image', target='image')[source]#

Bases: Transform, Resize

Resize image to specified size.

class stable_pretraining.data.transforms.RoundRobinMultiViewTransform(transforms)[source]#

Bases: Transform

Round-robin multi-view transform that cycles through transforms using a counter.

IMPORTANT: This transform is designed to work with RepeatedRandomSampler, where each image index appears multiple times consecutively in the batch. It uses an internal counter to apply different augmentations to each repeated occurrence.

BATCH SIZE NOTE: When using this with RepeatedRandomSampler, the batch_size parameter refers to the total number of augmented samples, NOT the number of unique images. For example, with batch_size=256 and n_views=2, you get 128 unique images, each appearing twice with different augmentations.

How it works:

  1. RepeatedRandomSampler produces indices like [0,0,1,1,2,2,…] (for n_views=2)

  2. The DataLoader loads the same image multiple times

  3. This transform applies a different augmentation to each occurrence, in round-robin order

Parameters:

transforms – List of transforms, one for each view. The counter cycles through these transforms in order.

Example

# With RepeatedRandomSampler(dataset, n_views=2)
transform = RoundRobinMultiViewTransform([
    strong_augmentation,  # Applied to 1st occurrence of each image
    weak_augmentation,    # Applied to 2nd occurrence of each image
])

Warning: The internal counter makes this transform stateful and not thread-safe.

class stable_pretraining.data.transforms.RoutingTransform(router: callable, transforms: list | tuple | dict)[source]#

Bases: Transform

Applies a routing callable to select which of several candidate transforms to apply.

class stable_pretraining.data.transforms.ToImage(dtype=torch.float32, scale=True, mean=None, std=None, source: str = 'image', target: str = 'image')[source]#

Bases: Transform

Convert input to image tensor with optional normalization.

class stable_pretraining.data.transforms.Transform[source]#

Bases: Transform

Base transform class extending torchvision v2.Transform with nested data handling.

get_name(x)[source]#
property name#
nested_get(v, name)[source]#
nested_set(original, value, name)[source]#
single_nested_get(v, name)[source]#
single_nested_set(original, value, name)[source]#
class stable_pretraining.data.transforms.UniformTemporalSubsample(num_samples: int, temporal_dim: int = -3, source: str = 'video', target: str = 'video')[source]#

Bases: Transform

nn.Module wrapper for pytorchvideo.transforms.functional.uniform_temporal_subsample.

forward(x: dict) Tensor[source]#

Do not override this! Use transform() instead.

class stable_pretraining.data.transforms.WrapTorchTransform(transform, source: str = 'image', target: str = 'image')[source]#

Bases: Transform, Lambda

Wraps a standard torch/torchvision transform to read the value at the source key and store the result at the target key.

stable_pretraining.data.transforms.random_seed(seed)[source]#
stable_pretraining.data.transforms.set_seed(seeds)[source]#
stable_pretraining.data.transforms.to_image(input: Tensor | Image | ndarray) Image[source]#

See ToImage for details.

stable_pretraining.data.utils module#

Utility functions for data manipulation and processing.

This module provides utility functions for working with datasets, including view folding for contrastive learning and dataset splitting.

stable_pretraining.data.utils.apply_masks(x: Tensor, *masks: Tensor) Tensor[source]#

Apply one or more masks to a batch of patched images.

This function is generalized to accept any number of mask tensors. If a single mask is provided, the output shape is [B, K, D]. If M masks are provided, the function creates M masked views and concatenates them along the batch dimension, resulting in an output of shape [B*M, K, D].

Example

>>> # xdoctest: +SKIP
>>> x = torch.randn(4, 196, 128)
>>> mask1 = torch.randint(0, 196, (4, 50))
>>> mask2 = torch.randint(0, 196, (4, 50))
>>> # Single mask case
>>> single_view = apply_masks(x, mask1)
>>> single_view.shape
torch.Size([4, 50, 128])
>>> # Multi-mask case
>>> multi_view = apply_masks(x, mask1, mask2)
>>> multi_view.shape
torch.Size([8, 50, 128])
Parameters:
  • x (torch.Tensor) – Input tensor of patches with shape [B, N, D].

  • *masks (torch.Tensor) – A variable number of mask tensors, each a tensor of indices with shape [B, K].

Returns:

The tensor of selected patches. The shape will be [B, K, D] for a single mask, or [B*M, K, D] for M masks.

Return type:

torch.Tensor

Raises:

ValueError – If no masks are provided.

stable_pretraining.data.utils.fold_views(tensor, idx)[source]#

Fold a tensor containing multiple views back into separate views.

Parameters:
  • tensor – Tensor containing concatenated views

  • idx – Sample indices to determine view boundaries

Returns:

Tuple of tensors, one for each view
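
Example

A hedged sketch of folding multi-view embeddings back apart; it assumes views are concatenated along the batch dimension and idx holds each row's sample index (e.g. produced by RepeatedRandomSampler).

    from stable_pretraining.data import fold_views

    # embeddings: [2*B, D] for two views; idx: [2*B] sample indices
    view1, view2 = fold_views(embeddings, idx)
    # view1, view2: [B, D] each, aligned sample-by-sample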

stable_pretraining.data.utils.get_num_workers()[source]#

Automatically determine the optimal number of DataLoader workers.

This function computes the ideal number of worker processes for PyTorch DataLoaders based on available CPU resources and distributed training configuration. It provides a zero-configuration approach that works reliably across different environments.

The calculation logic:
  1. Detect CPUs available to this process (respects affinity/cgroups)

  2. Divide by world_size if using DDP (each rank spawns its own workers)

  3. Return the result (always >= 1)

Returns:

Number of DataLoader workers to use. Minimum value is 1.

Return type:

int

Notes

  • Uses os.sched_getaffinity(0) on Linux to respect CPU affinity masks set by job schedulers (SLURM), containers (Docker), or taskset.

  • Falls back to os.cpu_count() on macOS/Windows.

  • In DDP mode, automatically divides by world_size since each process independently spawns workers.

  • Should be called AFTER distributed initialization for accurate results.

Examples

>>> # Simple usage
>>> num_workers = get_num_workers()
>>> loader = DataLoader(dataset, num_workers=num_workers)
>>> # In a Lightning DataModule
>>> class MyDataModule(L.LightningDataModule):
...     def train_dataloader(self):
...         return DataLoader(
...             self.train_dataset,
...             num_workers=get_num_workers(),
...             shuffle=True,
...         )
>>> # Example outputs:
>>> # - 16 CPUs, single GPU: returns 16
>>> # - 32 CPUs, 4 GPUs (DDP): returns 8 per GPU
>>> # - 8 CPUs, 8 GPUs (DDP): returns 1 per GPU
stable_pretraining.data.utils.random_split(dataset: Dataset, lengths: Sequence[int | float], generator: torch.Generator | None = default_generator) list[Subset][source]#

Randomly split a dataset into non-overlapping new datasets of given lengths.

If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.

Optionally fix the generator for reproducible results, e.g.:

Example

>>> # xdoctest: +SKIP
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
Parameters:
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths or fractions of splits to be produced

  • generator (Generator) – Generator used for the random permutation.

Module contents#

Data module for stable-pretraining.

This module provides dataset utilities, data loading, transformations, and other data-related functionality for the stable-pretraining framework.

class stable_pretraining.data.Categorical(values: list | Tensor, probabilities: list | Tensor)[source]#

Bases: Module

Categorical distribution for sampling discrete values with given probabilities.

sample(*args, **kwargs)[source]#
class stable_pretraining.data.Collator(G_from=None)[source]#

Bases: object

Custom collate function that optionally builds an affinity (or “graph”) matrix based on a specified field.

class stable_pretraining.data.DataModule(train: dict | DictConfig | DataLoader | None = None, test: dict | DictConfig | DataLoader | None = None, val: dict | DictConfig | DataLoader | None = None, predict: dict | DictConfig | DataLoader | None = None, **kwargs)[source]#

Bases: LightningDataModule

PyTorch Lightning DataModule for handling train/val/test/predict dataloaders.

load_state_dict(state_dict)[source]#

Called when loading a checkpoint; implement to reload datamodule state given the datamodule state_dict.

Parameters:

state_dict – the datamodule state returned by state_dict.

predict_dataloader()[source]#

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

set_pl_trainer(trainer: Trainer)[source]#
setup(stage)[source]#

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = load_data()

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
state_dict()[source]#

Called when saving a checkpoint; implement to generate and save datamodule state.

Returns:

A dictionary containing datamodule state.

teardown(stage: str)[source]#

Called at the end of fit (train + validate), validate, test, or predict.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

test_dataloader()[source]#

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader()[source]#

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()[source]#

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). This dataloader is used in conjunction with:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class stable_pretraining.data.Dataset(transform=None)[source]#

Bases: Dataset

Base dataset class with transform support and PyTorch Lightning integration.

process_sample(sample, **kwargs)[source]#
set_pl_trainer(trainer: Trainer)[source]#
class stable_pretraining.data.ExponentialMixtureNoiseModel(rates, prior, upper_bound=inf)[source]#

Bases: Module

Exponential mixture noise model for data augmentation or sampling.

sample(*args, **kwargs)[source]#
class stable_pretraining.data.ExponentialNormalNoiseModel(rate, mean, std, prior, upper_bound=inf)[source]#

Bases: Module

Exponential-normal noise model combining exponential and normal distributions.

sample(*args, **kwargs)[source]#
class stable_pretraining.data.FromTorchDataset(dataset, names, transform=None, add_sample_idx=True)[source]#

Bases: Dataset

Wrapper for PyTorch datasets with custom column naming and transforms.

Parameters:
  • dataset – PyTorch dataset to wrap

  • names – List of names for each element returned by the dataset

  • transform – Optional transform to apply to samples

  • add_sample_idx – If True, automatically adds ‘sample_idx’ field to each sample

property column_names#
class stable_pretraining.data.GMM(num_components=5, num_samples=100, dim=2)[source]#

Bases: Dataset

Gaussian Mixture Model dataset for synthetic data generation.

score(samples)[source]#
class stable_pretraining.data.HFDataset(*args, transform=None, rename_columns=None, remove_columns=None, **kwargs)[source]#

Bases: Dataset

Hugging Face dataset wrapper with transform and column manipulation support.

property column_names#
is_saved_with_save_to_disk(path)[source]#
class stable_pretraining.data.MinariEpisodeDataset(dataset)[source]#

Bases: Dataset

Dataset for Minari reinforcement learning data with episode-based access.

NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
property column_names#
nested_step(value, idx)[source]#
set_pl_trainer(trainer)[source]#
class stable_pretraining.data.MinariStepsDataset(dataset, num_steps=2, transform=None)[source]#

Bases: Dataset

Dataset for Minari reinforcement learning data with step-based access.

NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
property column_names#
nested_step(value, idx)[source]#
class stable_pretraining.data.RandomBatchSampler(batch_size: int, length_or_dataset: Dataset | int, *args, **kwargs)[source]#

Bases: Sampler[List[int]]

Yields mini-batches of randomly sampled indices.

Parameters:
  • batch_size (int) – Size of mini-batch.

  • length_or_dataset (Dataset or int) – Dataset to sample from, or the total number of samples.

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class stable_pretraining.data.RepeatedRandomSampler(data_source_or_len: int | Iterable, n_views: int = 1, replacement: bool = False, seed: int = 0, pass_view_idx: bool = False)[source]#

Bases: DistributedSampler

Sampler that repeats each dataset index consecutively for multi-view learning.

IMPORTANT: This sampler repeats each index n_views times in a row, creating sequences like [0,0,0,0, 1,1,1,1, 2,2,2,2, …] for n_views=4. This means:

  • The DataLoader will load the SAME image multiple times consecutively

  • Each repeated index goes through the transform pipeline separately

  • BATCH SIZE: The batch_size in DataLoader refers to total augmented samples. For example, batch_size=128 with n_views=8 means only 16 unique images, each appearing 8 times with different augmentations.

Designed to work with RoundRobinMultiViewTransform, which uses a counter to apply different augmentations to each repeated occurrence of the same image.

Example behavior with n_views=3:

Dataset indices: [0, 1, 2, 3, 4]
Sampler output: [0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4]

Parameters:
  • data_source (Dataset) – dataset to sample from

  • n_views (int) – number of times to repeat each index consecutively, default=1

  • replacement (bool) – samples are drawn on-demand with replacement if True, default=False

  • seed (int) – random seed for shuffling

  • pass_view_idx (bool) – whether to pass the view index to the dataset getitem

Note: For an alternative approach that loads each image once, consider using MultiViewTransform with a standard sampler.

class stable_pretraining.data.Subset(dataset: Dataset, indices: Sequence[int])[source]#

Bases: Dataset

Subset of a dataset at specified indices.

Parameters:
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset

property column_names#
dataset: Dataset#
indices: Sequence[int]#
class stable_pretraining.data.SupervisedBatchSampler(batch_size: int, n_views: int, targets_or_dataset: Dataset | list, *args, **kwargs)[source]#

Bases: Sampler[List[int]]

Yields mini-batches of indices grouped using the dataset targets.

Parameters:
  • batch_size (int) – Size of mini-batch.

  • n_views (int) – Number of views per sample.

  • targets_or_dataset (Dataset or list) – Dataset providing the targets, or the list of targets itself.

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
stable_pretraining.data.bulk_download(urls: Iterable[str], dest_folder: str | Path, backend: str = 'filesystem', cache_dir: str = '~/.stable_pretraining/')[source]#

Download multiple files concurrently.

Example

import stable_pretraining

stable_pretraining.data.bulk_download(
    [...],  # URLs to download
    "todelete",
)

Parameters:
  • urls (Iterable[str]) – List of URLs to download

  • dest_folder (Union[str, Path]) – Destination folder for downloads

  • backend (str, optional) – Storage backend type. Defaults to “filesystem”.

  • cache_dir (str, optional) – Cache directory path. Defaults to “~/.stable_pretraining/”.

stable_pretraining.data.download(url, dest_folder, backend='filesystem', cache_dir='~/.stable_pretraining/', progress_bar=True, _progress_dict=None, _task_id=None)[source]#

Download a file from a URL with progress tracking.

Parameters:
  • url – URL to download from

  • dest_folder – Destination folder for the download

  • backend – Storage backend type

  • cache_dir – Cache directory path

  • progress_bar – Whether to show progress bar

  • _progress_dict – Internal dictionary for progress tracking

  • _task_id – Internal task ID for bulk downloads

Returns:

Path to the downloaded file or None if download failed

stable_pretraining.data.fold_views(tensor, idx)[source]#

Fold a tensor containing multiple views back into separate views.

Parameters:
  • tensor – Tensor containing concatenated views

  • idx – Sample indices to determine view boundaries

Returns:

Tuple of tensors, one for each view

stable_pretraining.data.generate_perlin_noise_2d(shape, res, octaves=1, persistence=0.5, lacunarity=2.0)[source]#

Generate 2D Perlin noise.

Parameters:
  • shape – Output shape (height, width)

  • res – Resolution tuple

  • octaves – Number of octaves for fractal noise

  • persistence – Amplitude multiplier for each octave

  • lacunarity – Frequency multiplier for each octave

Returns:

2D tensor of Perlin noise

stable_pretraining.data.perlin_noise_3d(x, y, z)[source]#

Generate 3D Perlin noise at given coordinates.

Parameters:
  • x – X coordinate for noise generation

  • y – Y coordinate for noise generation

  • z – Z coordinate for noise generation

Returns:

Perlin noise value at the given coordinates

stable_pretraining.data.random_split(dataset: Dataset, lengths: Sequence[int | float], generator: torch.Generator | None = default_generator) list[Subset][source]#

Randomly split a dataset into non-overlapping new datasets of given lengths.

If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.

Optionally fix the generator for reproducible results, e.g.:

Example

>>> # xdoctest: +SKIP
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
Parameters:
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths or fractions of splits to be produced

  • generator (Generator) – Generator used for the random permutation.

stable_pretraining.data.swiss_roll(N, margin=1, sampler_time=Uniform(low=0.1, high=3.0), sampler_width=Uniform(low=0.0, high=1.0))[source]#

Generate Swiss Roll dataset points.

Parameters:
  • N – Number of points to generate

  • margin – Margin parameter for the roll

  • sampler_time – Distribution for sampling time parameter

  • sampler_width – Distribution for sampling width parameter

Returns:

Tensor of shape (N, 3) containing Swiss Roll points