stable_pretraining.data package#
Submodules#
stable_pretraining.data.collate module#
stable_pretraining.data.dataset_stats module#
Dataset statistics for normalization.
This module contains pre-computed mean and standard deviation values for various common datasets, used for data normalization during preprocessing.
stable_pretraining.data.datasets module#
Dataset classes for real data sources.
This module provides dataset wrappers and utilities for working with real data sources including PyTorch datasets, HuggingFace datasets, and dataset subsets.
- class stable_pretraining.data.datasets.Dataset(transform=None)[source]#
Bases: Dataset
Base dataset class with transform support and PyTorch Lightning integration.
- class stable_pretraining.data.datasets.FromTorchDataset(dataset, names, transform=None, add_sample_idx=True)[source]#
Bases: Dataset
Wrapper for PyTorch datasets with custom column naming and transforms.
- Parameters:
dataset – PyTorch dataset to wrap
names – List of names for each element returned by the dataset
transform – Optional transform to apply to samples
add_sample_idx – If True, automatically adds ‘sample_idx’ field to each sample
- property column_names#
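A minimal usage sketch; the torchvision dataset and the column names are illustrative, not part of this API, and the dict-shaped sample follows from the add_sample_idx description above:

    import torch
    from torchvision.datasets import CIFAR10
    from stable_pretraining.data import FromTorchDataset

    # CIFAR10 yields (image, label) tuples; map them to named columns
    base = CIFAR10(root="./data", train=True, download=True)
    dataset = FromTorchDataset(base, names=["image", "label"])
    sample = dataset[0]
    # sample is a dict with keys "image", "label", and (by default) "sample_idx"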
- class stable_pretraining.data.datasets.HFDataset(*args, transform=None, rename_columns=None, remove_columns=None, **kwargs)[source]#
Bases: Dataset
Hugging Face dataset wrapper with transform and column manipulation support.
- property column_names#
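A usage sketch, assuming the positional and keyword arguments are forwarded to datasets.load_dataset (the dataset name and column mapping are illustrative):

    from stable_pretraining.data import HFDataset

    dataset = HFDataset(
        "uoft-cs/cifar10",                 # forwarded to datasets.load_dataset (assumed)
        split="train",
        rename_columns={"img": "image"},   # illustrative column mapping
    )
    print(dataset.column_names)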
stable_pretraining.data.download module#
Download utilities for fetching datasets and model weights.
This module provides functions for downloading files from URLs with progress tracking, caching, and concurrent download support.
- stable_pretraining.data.download.bulk_download(urls: Iterable[str], dest_folder: str | Path, backend: str = 'filesystem', cache_dir: str = '~/.stable_pretraining/')[source]#
Download multiple files concurrently.
Example
import stable_pretraining

stable_pretraining.data.bulk_download(
    [
        # URLs elided in the rendered docstring
    ],
    "todelete",
)
- stable_pretraining.data.download.download(url, dest_folder, backend='filesystem', cache_dir='~/.stable_pretraining/', progress_bar=True, _progress_dict=None, _task_id=None)[source]#
Download a file from a URL with progress tracking.
- Parameters:
url – URL to download from
dest_folder – Destination folder for the download
backend – Storage backend type
cache_dir – Cache directory path
progress_bar – Whether to show progress bar
_progress_dict – Internal dictionary for progress tracking
_task_id – Internal task ID for bulk downloads
- Returns:
Path to the downloaded file or None if download failed
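A usage sketch with a hypothetical URL; per the Returns note above, a failed download yields None:

    from stable_pretraining.data import download

    path = download(
        "https://example.com/checkpoint.ckpt",  # hypothetical URL
        dest_folder="./downloads",
    )
    if path is None:
        print("download failed")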
stable_pretraining.data.masking module#
- stable_pretraining.data.masking.multi_block_mask(height: int, width: int, block_scales: list[tuple[float, float]] = [(0.85, 1.0), (0.15, 0.2), (0.15, 0.2), (0.15, 0.2), (0.15, 0.2)], aspect_ratios: list[tuple[float, float]] = [(1.0, 1.0), (0.75, 1.5), (0.75, 1.5), (0.75, 1.5), (0.75, 1.5)], min_keep: int = 1, seed: int = 0) list[Tensor, ...][source]#
Generates a list of distinct, randomly placed, block-shaped binary masks.
This function creates a series of block masks based on provided scale and aspect ratio specifications. For each pair of (scale, aspect_ratio), it first samples a block size (height, width). It then places this block at a random location within the grid of the specified height and width.
The process is repeated for all items in the input lists to produce a list of independent masks.
Example
>>> # xdoctest: +SKIP
>>> # Generate masks for a 14x14 patch grid
>>> masks = multi_block_mask(height=14, width=14)
>>> len(masks)
5
>>> masks[0].nonzero().size(0)
169
>>> masks[1].nonzero().size(0)
30
- Parameters:
height (int) – The height of the grid to generate masks for (in patches).
width (int) – The width of the grid to generate masks for (in patches).
block_scales (list[tuple[float, float]]) – A list of tuples, each defining the min/max scale of a block’s area relative to the total grid area.
aspect_ratios (list[tuple[float, float]]) – A list of tuples, each defining the min/max aspect ratio (height/width) for a corresponding block.
min_keep (int) – The minimum number of `1`s required for a valid block mask.
seed (int) – A seed for the random number generator to ensure reproducibility.
- Returns:
A list of 2D binary mask tensors, one per (scale, aspect ratio) pair, each of shape (height, width).
- Return type:
list[torch.Tensor]
stable_pretraining.data.module module#
- class stable_pretraining.data.module.DataModule(train: dict | DictConfig | DataLoader | None = None, test: dict | DictConfig | DataLoader | None = None, val: dict | DictConfig | DataLoader | None = None, predict: dict | DictConfig | DataLoader | None = None, **kwargs)[source]#
Bases: LightningDataModule
PyTorch Lightning DataModule for handling train/val/test/predict dataloaders.
- load_state_dict(state_dict)[source]#
Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
- Parameters:
state_dict – the datamodule state returned by state_dict().
- predict_dataloader()[source]#
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- setup(stage)[source]#
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()
        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- state_dict()[source]#
Called when saving a checkpoint, implement to generate and save datamodule state.
- Returns:
A dictionary containing datamodule state.
- teardown(stage: str)[source]#
Called at the end of fit (train + validate), validate, test, or predict.
- Parameters:
stage – either
'fit','validate','test', or'predict'
- test_dataloader()[source]#
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- train_dataloader()[source]#
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()[source]#
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
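A minimal construction sketch; per the signature above, each split accepts a ready DataLoader (or a dict/DictConfig describing one). The train_dataset and val_dataset names are placeholders:

    from torch.utils.data import DataLoader
    from stable_pretraining.data import DataModule

    datamodule = DataModule(
        train=DataLoader(train_dataset, batch_size=256, shuffle=True),
        val=DataLoader(val_dataset, batch_size=256),
    )
    # trainer.fit(module, datamodule=datamodule)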
stable_pretraining.data.sampler module#
- class stable_pretraining.data.sampler.RandomBatchSampler(batch_size: int, length_or_dataset: Dataset | int, *args, **kwargs)[source]#
Wraps another sampler to yield a mini-batch of indices.
Example
>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- class stable_pretraining.data.sampler.RepeatedRandomSampler(data_source_or_len: int | Iterable, n_views: int = 1, replacement: bool = False, seed: int = 0, pass_view_idx: bool = False)[source]#
Bases: DistributedSampler
Sampler that repeats each dataset index consecutively for multi-view learning.
IMPORTANT: This sampler repeats each index n_views times in a row, creating sequences like [0,0,0,0, 1,1,1,1, 2,2,2,2, …] for n_views=4. This means:
- The DataLoader will load the SAME image multiple times consecutively
- Each repeated index goes through the transform pipeline separately
- BATCH SIZE: The batch_size in DataLoader refers to total augmented samples. For example, batch_size=128 with n_views=8 means only 16 unique images, each appearing 8 times with different augmentations.
Designed to work with RoundRobinMultiViewTransform which uses a counter to apply different augmentations to each repeated occurrence of the same image.
- Example behavior with n_views=3:
Dataset indices: [0, 1, 2, 3, 4]
Sampler output: [0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4]
- Parameters:
data_source (Dataset) – dataset to sample from
n_views (int) – number of times to repeat each index consecutively, default=1
replacement (bool) – samples are drawn on-demand with replacement if True, default=False
seed (int) – random seed for shuffling
pass_view_idx (bool) – whether to pass the view index to the dataset getitem
Note: For an alternative approach that loads each image once, consider using MultiViewTransform with a standard sampler.
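A wiring sketch of the batch-size arithmetic above (dataset is a placeholder map-style dataset); pair the loader with RoundRobinMultiViewTransform so each repeat receives a different augmentation:

    from torch.utils.data import DataLoader
    from stable_pretraining.data import RepeatedRandomSampler

    # 128 total augmented samples per batch = 16 unique images x 8 views
    sampler = RepeatedRandomSampler(len(dataset), n_views=8)
    loader = DataLoader(dataset, sampler=sampler, batch_size=128)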
- class stable_pretraining.data.sampler.SupervisedBatchSampler(batch_size: int, n_views: int, targets_or_dataset: Dataset | list, *args, **kwargs)[source]#
Wraps another sampler to yield a mini-batch of indices.
Example
>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
stable_pretraining.data.synthetic_data module#
Synthetic and simulated data generators.
This module provides various synthetic data generators including manifold datasets, noise generators, statistical models, and simulated environments for testing and experimentation purposes.
- class stable_pretraining.data.synthetic_data.Categorical(values: list | Tensor, probabilities: list | Tensor)[source]#
Bases: Module
Categorical distribution for sampling discrete values with given probabilities.
- class stable_pretraining.data.synthetic_data.ExponentialMixtureNoiseModel(rates, prior, upper_bound=inf)[source]#
Bases: Module
Exponential mixture noise model for data augmentation or sampling.
- class stable_pretraining.data.synthetic_data.ExponentialNormalNoiseModel(rate, mean, std, prior, upper_bound=inf)[source]#
Bases: Module
Exponential-normal noise model combining exponential and normal distributions.
- class stable_pretraining.data.synthetic_data.GMM(num_components=5, num_samples=100, dim=2)[source]#
Bases: Dataset
Gaussian Mixture Model dataset for synthetic data generation.
- class stable_pretraining.data.synthetic_data.MinariEpisodeDataset(dataset)[source]#
Bases: Dataset
Dataset for Minari reinforcement learning data with episode-based access.
- NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
- property column_names#
- class stable_pretraining.data.synthetic_data.MinariStepsDataset(dataset, num_steps=2, transform=None)[source]#
Bases: Dataset
Dataset for Minari reinforcement learning data with step-based access.
- NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
- property column_names#
- stable_pretraining.data.synthetic_data.generate_perlin_noise_2d(shape, res, octaves=1, persistence=0.5, lacunarity=2.0)[source]#
Generate 2D Perlin noise.
- Parameters:
shape – Output shape (height, width)
res – Resolution tuple
octaves – Number of octaves for fractal noise
persistence – Amplitude multiplier for each octave
lacunarity – Frequency multiplier for each octave
- Returns:
2D tensor of Perlin noise
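A call sketch using the parameters documented above; the shape and res values are illustrative:

    noise = generate_perlin_noise_2d(shape=(256, 256), res=(8, 8), octaves=4)
    # 256x256 tensor of fractal noise; res sets the base lattice resolution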
- stable_pretraining.data.synthetic_data.perlin_noise_3d(x, y, z)[source]#
Generate 3D Perlin noise at given coordinates.
- Parameters:
x – X coordinate for noise generation
y – Y coordinate for noise generation
z – Z coordinate for noise generation
- Returns:
Perlin noise value at the given coordinates
- stable_pretraining.data.synthetic_data.swiss_roll(N, margin=1, sampler_time=Uniform(low: 0.10000000149011612, high: 3.0), sampler_width=Uniform(low: 0.0, high: 1.0))[source]#
Generate Swiss Roll dataset points.
- Parameters:
N – Number of points to generate
margin – Margin parameter for the roll
sampler_time – Distribution for sampling time parameter
sampler_width – Distribution for sampling width parameter
- Returns:
Tensor of shape (N, 3) containing Swiss Roll points
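A call sketch; time and width are drawn from the default Uniform samplers shown in the signature:

    points = swiss_roll(N=1000)
    # points has shape (1000, 3)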
stable_pretraining.data.transforms module#
- class stable_pretraining.data.transforms.AdditiveGaussian(sigma, p=1)[source]#
Bases: Transform
Add Gaussian noise to input data.
- BYPASS_VALUE = False#
- class stable_pretraining.data.transforms.CenterCrop(size, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, CenterCrop
Crop the center of an image to the given size.
- class stable_pretraining.data.transforms.ColorJitter(brightness=None, contrast=None, saturation=None, hue=None, p=1, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, ColorJitter
Randomly change brightness, contrast, saturation, and hue of an image.
- static get_params(brightness: list[float] | None, contrast: list[float] | None, saturation: list[float] | None, hue: list[float] | None) tuple[Tensor, float | None, float | None, float | None, float | None][source]#
Get the parameters for the randomized transform to be applied to the image.
- Parameters:
brightness (tuple of float (min, max), optional) – The range from which the brightness_factor is chosen uniformly. Pass None to turn off the transformation.
contrast (tuple of float (min, max), optional) – The range from which the contrast_factor is chosen uniformly. Pass None to turn off the transformation.
saturation (tuple of float (min, max), optional) – The range from which the saturation_factor is chosen uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional) – The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation.
- Returns:
The parameters used to apply the randomized transform along with their random order.
- Return type:
tuple
- class stable_pretraining.data.transforms.Compose(*args)[source]#
Bases: Transform
Compose multiple transforms together in sequence.
- class stable_pretraining.data.transforms.Conditional(transform, condition_key, apply_on_true=True)[source]#
Bases: Transform
Apply transform conditionally based on a data dictionary key.
- class stable_pretraining.data.transforms.ContextTargetsMultiBlockMask(patch_size=16, context_scale=(0.85, 1.0), context_aspect_ratio=(1.0, 1.0), target_scales=((0.15, 0.2), (0.15, 0.2), (0.15, 0.2), (0.15, 0.2)), target_aspect_ratios=((0.75, 1.5), (0.75, 1.5), (0.75, 1.5), (0.75, 1.5)), min_keep=10, source: str = 'image', target_context: str = 'mask_context', target_targets: str = 'masks_target')[source]#
Bases: Transform
Transform that adds multi-block masks to the batch, with multiple target blocks and one disjoint context block.
- Parameters:
patch_size – Size (in pixels) of each square patch
context_scale – Scale range of the context block, relative to the total grid area
context_aspect_ratio – Aspect ratio range of the context block
target_scales – Scale range for each target block
target_aspect_ratios – Aspect ratio range for each target block
min_keep – Minimum number of patches that must be in each block
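A usage sketch with the default mask keys from the signature; the input image size is illustrative:

    import torch
    from stable_pretraining.data.transforms import ContextTargetsMultiBlockMask

    transform = ContextTargetsMultiBlockMask(patch_size=16)
    sample = transform({"image": torch.rand(3, 224, 224)})
    # sample["mask_context"] holds the context block mask and
    # sample["masks_target"] the target block masks (key names set by
    # target_context / target_targets)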
- class stable_pretraining.data.transforms.ControlledTransform(transform: callable, seed_offset: int = 0, key: str | None = 'idx')[source]#
Bases: Transform
Applies the wrapped transform with per-sample controlled randomness, seeding it from the sample’s key (default ‘idx’) plus a seed offset so augmentations are reproducible.
- class stable_pretraining.data.transforms.GaussianBlur(kernel_size, sigma=(0.1, 2.0), p=1, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, GaussianBlur
Apply Gaussian blur to image with random sigma values.
- class stable_pretraining.data.transforms.Lambda(lambd, source: str = 'image', target: str = 'image')[source]#
Bases: Transform
Applies a lambda callable to the value at the source key and stores the result at the target key.
- class stable_pretraining.data.transforms.MultiViewTransform(transforms)[source]#
Bases: Transform
Creates multiple views from one sample by applying different transforms.
Takes a single sample and applies different transforms to create multiple views, returning a list of complete sample dicts. Preserves all modifications each transform makes (masks, augmentation params, metadata, etc.).
- Implementation Note:
This transform uses shallow copy (dict.copy()) for the input sample before applying each transform. This is efficient and safe because:
- The shallow copy shares references to the original tensors/objects
- Standard transforms create NEW tensors (e.g., through mul(), resize(), crop()) rather than modifying inputs in-place
- The original sample remains unchanged
- Consequences of shallow copy:
- Memory efficient: Original tensors are not duplicated unnecessarily
- Safe with torchvision transforms: All torchvision transforms and our custom transforms follow the pattern of creating new tensors
- Caution: If using custom transforms that modify tensors in-place (operations like mul_() or add_() with a trailing underscore), views may interfere with each other. Always use non-in-place operations in custom transforms.
- Parameters:
transforms – Either a list or dict of transforms.
- List: Returns a list of views in the same order
- Dict: Returns a dict of views with the same keys
- Returns:
If transforms is a list: Returns a list of transformed sample dicts
If transforms is a dict: Returns a dict of transformed sample dicts with same keys
Each dict contains NEW tensors, not references to the original.
- Return type:
list or dict
Example
# List input - returns list of views
transform = MultiViewTransform([
    strong_augmentation,  # Creates first view with strong aug
    weak_augmentation,    # Creates second view with weak aug
])
# Input: {"image": img, "label": 0}
# Output: [{"image": img_strong, "label": 0}, {"image": img_weak, "label": 0}]

# Dict input - returns dict of named views
transform = MultiViewTransform({
    "student": strong_augmentation,
    "teacher": weak_augmentation,
})
# Input: {"image": img, "label": 0}
# Output: {"student": {"image": img_strong, "label": 0},
#          "teacher": {"image": img_weak, "label": 0}}
- class stable_pretraining.data.transforms.PILGaussianBlur(sigma=None, p=1, source: str = 'image', target: str = 'image')[source]#
Bases: Transform
PIL-based Gaussian blur transform with random sigma sampling.
- class stable_pretraining.data.transforms.PatchMasking(patch_size: int = 16, drop_ratio: float = 0.5, source: str = 'image', target: str = 'image', fill_value: float = None, mask_key: str = 'patch_mask')[source]#
Bases: Transform
Randomly masks square patches in an image, similar to patch masking used in Masked Signal Encoding (MSE) tasks.
This transform operates on a dictionary input, applies patch masking to the image found at the specified source key, and writes the masked image to the target key. It also saves a boolean mask matrix (one entry per patch) to the mask_key in the dictionary, indicating which patches were masked (False) or kept (True). The output image remains in the same format as the input (PIL Image or Tensor), and the masking is performed efficiently for both input types.
- Parameters:
patch_size (int) – The size (in pixels) of each square patch to be masked.
drop_ratio (float) – The exact fraction of patches to randomly mask (set to the mask value).
source (str) – The key in the input dictionary from which to read the image.
target (str) – The key in the output dictionary to which the masked image will be written.
mask_key (str) – The key in the output dictionary to which the boolean patch mask will be written.
fill_value (float, optional) – The value to use for masked patches. If None, defaults to 0.0 for float tensors, and 128/255.0 for PIL images (mid-gray). Can be set to any float in [0,1] for normalized images.
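A usage sketch; with a 224x224 image and patch_size=16 the mask grid is 14x14 (the input tensor here is illustrative):

    import torch
    from stable_pretraining.data.transforms import PatchMasking

    masking = PatchMasking(patch_size=16, drop_ratio=0.5)
    sample = masking({"image": torch.rand(3, 224, 224)})
    # sample["image"] has half of its 16x16 patches filled with fill_value;
    # sample["patch_mask"] is the boolean per-patch grid (True = kept)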
- class stable_pretraining.data.transforms.RGB(source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RGB
Convert image to RGB format.
- class stable_pretraining.data.transforms.RandomChannelPermutation(source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomChannelPermutation
Randomly permute the channels of an image.
- class stable_pretraining.data.transforms.RandomContiguousTemporalSampler(source, target, num_frames, frame_subsampling: int = 1)[source]#
Bases: Transform
Randomly sample contiguous frames from a video sequence.
- class stable_pretraining.data.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant', source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomCrop
Crop the given image at a random location.
- class stable_pretraining.data.transforms.RandomGrayscale(p=0.1, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomGrayscale
Randomly convert image to grayscale with given probability.
- class stable_pretraining.data.transforms.RandomHorizontalFlip(p=0.5, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomHorizontalFlip
Horizontally flip the given image randomly with a given probability.
- class stable_pretraining.data.transforms.RandomMask(patch_size=16, mask_ratio=0.75, source: str = 'image', target_visible: str = 'mask_visible', target_masked: str = 'mask_masked', target_ids_restore: str = 'ids_restore', target_len_keep: str = 'len_keep')[source]#
Bases: Transform
Creates a random MAE-style mask for an image.
This transform generates a random permutation of all patch indices for an input image. It then splits these indices into two disjoint sets: ‘visible’ and ‘masked’, according to the specified mask_ratio.
It also provides an ids_restore tensor, which can un-shuffle a sequence of patches back to its original 2D grid order. All outputs are added as new keys to the sample dictionary.
Example
>>> # xdoctest: +SKIP
>>> transform = RandomMask(patch_size=16, mask_ratio=0.75)
>>> sample = {"image": torch.randn(3, 224, 224)}
>>> result = transform(sample)
>>> sorted(result.keys())
['image', 'ids_restore', 'len_keep', 'mask_masked', 'mask_visible']
>>> result["len_keep"]
49
>>> result["mask_visible"].shape
torch.Size([49])
- Parameters:
patch_size (int) – The height and width of each square patch.
mask_ratio (float) – The fraction of patches to be masked (e.g., 0.75).
source (str) – The key in the sample dict for the source image tensor.
target_visible (str) – The key to use when storing visible patch indices.
target_masked (str) – The key to use when storing masked patch indices.
target_ids_restore (str) – The key to use for the restoration indices.
target_len_keep (str) – The key to use for the count of visible patches.
- class stable_pretraining.data.transforms.RandomResizedCrop(size: int | Sequence[int], scale: Tuple[float, float] = (0.08, 1.0), ratio: Tuple[float, float] = (0.75, 1.3333333333333333), interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, antialias: bool | None = True, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomResizedCrop
Crop a random portion of the image and resize it to the given size.
- class stable_pretraining.data.transforms.RandomRotation(degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomRotation
Rotate image by a random angle within the specified degrees range.
- class stable_pretraining.data.transforms.RandomSolarize(threshold, p=0.5, source: str = 'image', target: str = 'image')[source]#
Bases: Transform, RandomSolarize
Randomly solarize the image by inverting pixel values above a threshold.
- class stable_pretraining.data.transforms.Resize(size, interpolation=2, max_size=None, antialias=True, source='image', target='image')[source]#
Bases: Transform, Resize
Resize image to the specified size.
- class stable_pretraining.data.transforms.RoundRobinMultiViewTransform(transforms)[source]#
Bases: Transform
Round-robin multi-view transform that cycles through transforms using a counter.
IMPORTANT: This transform is designed to work with RepeatedRandomSampler, where each image index appears multiple times consecutively in the batch. It uses an internal counter to apply different augmentations to each repeated occurrence.
BATCH SIZE NOTE: When using this with RepeatedRandomSampler, the batch_size parameter refers to the total number of augmented samples, NOT the number of unique images. For example, with batch_size=256 and n_views=2, you get 128 unique images, each appearing twice with different augmentations.
How it works:
1. RepeatedRandomSampler produces indices like [0,0,1,1,2,2,…] (for n_views=2)
2. The DataLoader loads the same image multiple times
3. This transform applies a different augmentation each time, in round-robin order
- Parameters:
transforms – List of transforms, one for each view. The counter cycles through these transforms in order.
Example
# With RepeatedRandomSampler(dataset, n_views=2)
transform = RoundRobinMultiViewTransform([
    strong_augmentation,  # Applied to 1st occurrence of each image
    weak_augmentation,    # Applied to 2nd occurrence of each image
])
Warning: The internal counter makes this transform stateful and not thread-safe.
- class stable_pretraining.data.transforms.RoutingTransform(router: callable, transforms: list | tuple | dict)[source]#
Bases: Transform
Applies a routing callable to conditionally apply one transform from many candidates.
- class stable_pretraining.data.transforms.ToImage(dtype=torch.float32, scale=True, mean=None, std=None, source: str = 'image', target: str = 'image')[source]#
Bases: Transform
Convert input to image tensor with optional normalization.
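A sketch composing the dict-based transforms above into a training pipeline; Compose takes transforms as positional arguments per its signature, and the normalization statistics are the usual ImageNet values, given here for illustration:

    from stable_pretraining.data import transforms

    train_transform = transforms.Compose(
        transforms.RGB(),
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToImage(
            mean=[0.485, 0.456, 0.406],  # ImageNet statistics (illustrative)
            std=[0.229, 0.224, 0.225],
        ),
    )
    # each transform reads and writes the "image" key of the sample dict by default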
- class stable_pretraining.data.transforms.Transform[source]#
Bases: Transform
Base transform class extending torchvision v2.Transform with nested data handling.
- property name#
- class stable_pretraining.data.transforms.UniformTemporalSubsample(num_samples: int, temporal_dim: int = -3, source: str = 'video', target: str = 'video')[source]#
Bases: Transform
nn.Module wrapper for pytorchvideo.transforms.functional.uniform_temporal_subsample.
stable_pretraining.data.utils module#
Utility functions for data manipulation and processing.
This module provides utility functions for working with datasets, including view folding for contrastive learning and dataset splitting.
- stable_pretraining.data.utils.apply_masks(x: Tensor, *masks: Tensor) Tensor[source]#
Apply one or more masks to a batch of patched images.
This function is generalized to accept any number of mask tensors. If a single mask is provided, the output shape is [B, K, D]. If M masks are provided, the function creates M masked views and concatenates them along the batch dimension, resulting in an output of shape [B*M, K, D].
Example
>>> # xdoctest: +SKIP
>>> x = torch.randn(4, 196, 128)
>>> mask1 = torch.randint(0, 196, (4, 50))
>>> mask2 = torch.randint(0, 196, (4, 50))
>>> # Single mask case
>>> single_view = apply_masks(x, mask1)
>>> single_view.shape
torch.Size([4, 50, 128])
>>> # Multi-mask case
>>> multi_view = apply_masks(x, mask1, mask2)
>>> multi_view.shape
torch.Size([8, 50, 128])
- Parameters:
x (torch.Tensor) – Input tensor of patches with shape [B, N, D].
*masks (torch.Tensor) – A variable number of mask tensors, each a tensor of indices with shape [B, K].
- Returns:
The tensor of selected patches. The shape will be [B, K, D] for a single mask, or [B*M, K, D] for M masks.
- Return type:
torch.Tensor
- Raises:
ValueError – If no masks are provided.
- stable_pretraining.data.utils.fold_views(tensor, idx)[source]#
Fold a tensor containing multiple views back into separate views.
- Parameters:
tensor – Tensor containing concatenated views
idx – Sample indices to determine view boundaries
- Returns:
Tuple of tensors, one for each view
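A sketch of the intended call pattern; the exact index layout follows the sampler in use, and the consecutive two-views-per-sample layout and per-view output shapes below are assumptions for illustration:

    import torch
    from stable_pretraining.data import fold_views

    embeddings = torch.randn(6, 128)        # 3 samples x 2 views, concatenated
    idx = torch.tensor([0, 0, 1, 1, 2, 2])  # sample index of each row (assumed layout)
    view_a, view_b = fold_views(embeddings, idx)  # assumed: one (3, 128) tensor per view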
- stable_pretraining.data.utils.get_num_workers()[source]#
Automatically determine the optimal number of DataLoader workers.
This function computes the ideal number of worker processes for PyTorch DataLoaders based on available CPU resources and distributed training configuration. It provides a zero-configuration approach that works reliably across different environments.
- The calculation logic:
1. Detect CPUs available to this process (respects affinity/cgroups)
2. Divide by world_size if using DDP (each rank spawns its own workers)
3. Return the result (always >= 1)
- Returns:
Number of DataLoader workers to use. Minimum value is 1.
- Return type:
int
Notes
Uses os.sched_getaffinity(0) on Linux to respect CPU affinity masks set by job schedulers (SLURM), containers (Docker), or taskset.
Falls back to os.cpu_count() on macOS/Windows.
In DDP mode, automatically divides by world_size since each process independently spawns workers.
Should be called AFTER distributed initialization for accurate results.
Examples
>>> # Simple usage
>>> num_workers = get_num_workers()
>>> loader = DataLoader(dataset, num_workers=num_workers)
>>> # In a Lightning DataModule
>>> class MyDataModule(L.LightningDataModule):
...     def train_dataloader(self):
...         return DataLoader(
...             self.train_dataset,
...             num_workers=get_num_workers(),
...             shuffle=True,
...         )
>>> # Example outputs:
>>> # - 16 CPUs, single GPU: returns 16
>>> # - 32 CPUs, 4 GPUs (DDP): returns 8 per GPU
>>> # - 8 CPUs, 8 GPUs (DDP): returns 1 per GPU
See also
PyTorch DataLoader: https://pytorch.org/docs/stable/data.html
CPU affinity: https://man7.org/linux/man-pages/man2/sched_setaffinity.2.html
- stable_pretraining.data.utils.random_split(dataset: stable_pretraining.data.datasets.Dataset, lengths: Sequence[int | float], generator: torch.Generator | None = <torch._C.Generator object>) list[Subset][source]#
Randomly split a dataset into non-overlapping new datasets of given lengths.
If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.
After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.
Optionally fix the generator for reproducible results, e.g.:
Example
>>> # xdoctest: +SKIP
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- Parameters:
dataset (Dataset) – Dataset to be split
lengths (sequence) – lengths or fractions of splits to be produced
generator (Generator) – Generator used for the random permutation.
Module contents#
Data module for stable-pretraining.
This module provides dataset utilities, data loading, transformations, and other data-related functionality for the stable-pretraining framework.
- class stable_pretraining.data.Categorical(values: list | Tensor, probabilities: list | Tensor)[source]#
Bases: Module
Categorical distribution for sampling discrete values with given probabilities.
- class stable_pretraining.data.Collator(G_from=None)[source]#
Bases: object
Custom collate function that optionally builds an affinity (or “graph”) matrix based on a specified field.
- class stable_pretraining.data.DataModule(train: dict | DictConfig | DataLoader | None = None, test: dict | DictConfig | DataLoader | None = None, val: dict | DictConfig | DataLoader | None = None, predict: dict | DictConfig | DataLoader | None = None, **kwargs)[source]#
Bases: LightningDataModule
PyTorch Lightning DataModule for handling train/val/test/predict dataloaders.
- load_state_dict(state_dict)[source]#
Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
- Parameters:
state_dict – the datamodule state returned by state_dict().
- predict_dataloader()[source]#
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- setup(stage)[source]#
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()
        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- state_dict()[source]#
Called when saving a checkpoint, implement to generate and save datamodule state.
- Returns:
A dictionary containing datamodule state.
- teardown(stage: str)[source]#
Called at the end of fit (train + validate), validate, test, or predict.
- Parameters:
stage – either
'fit','validate','test', or'predict'
- test_dataloader()[source]#
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- train_dataloader()[source]#
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()[source]#
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
- class stable_pretraining.data.Dataset(transform=None)[source]#
Bases: Dataset
Base dataset class with transform support and PyTorch Lightning integration.
- class stable_pretraining.data.ExponentialMixtureNoiseModel(rates, prior, upper_bound=inf)[source]#
Bases: Module
Exponential mixture noise model for data augmentation or sampling.
- class stable_pretraining.data.ExponentialNormalNoiseModel(rate, mean, std, prior, upper_bound=inf)[source]#
Bases: Module
Exponential-normal noise model combining exponential and normal distributions.
- class stable_pretraining.data.FromTorchDataset(dataset, names, transform=None, add_sample_idx=True)[source]#
Bases: Dataset
Wrapper for PyTorch datasets with custom column naming and transforms.
- Parameters:
dataset – PyTorch dataset to wrap
names – List of names for each element returned by the dataset
transform – Optional transform to apply to samples
add_sample_idx – If True, automatically adds ‘sample_idx’ field to each sample
- property column_names#
- class stable_pretraining.data.GMM(num_components=5, num_samples=100, dim=2)[source]#
Bases: Dataset
Gaussian Mixture Model dataset for synthetic data generation.
- class stable_pretraining.data.HFDataset(*args, transform=None, rename_columns=None, remove_columns=None, **kwargs)[source]#
Bases: Dataset
Hugging Face dataset wrapper with transform and column manipulation support.
- property column_names#
- class stable_pretraining.data.MinariEpisodeDataset(dataset)[source]#
Bases: Dataset
Dataset for Minari reinforcement learning data with episode-based access.
- NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
- property column_names#
- class stable_pretraining.data.MinariStepsDataset(dataset, num_steps=2, transform=None)[source]#
Bases: Dataset
Dataset for Minari reinforcement learning data with step-based access.
- NAMES = ['observations', 'actions', 'rewards', 'terminations', 'truncations']#
- property column_names#
- class stable_pretraining.data.RandomBatchSampler(batch_size: int, length_or_dataset: Dataset | int, *args, **kwargs)[source]#
Wraps another sampler to yield a mini-batch of indices.
Example
>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- class stable_pretraining.data.RepeatedRandomSampler(data_source_or_len: int | Iterable, n_views: int = 1, replacement: bool = False, seed: int = 0, pass_view_idx: bool = False)[source]#
Bases: DistributedSampler
Sampler that repeats each dataset index consecutively for multi-view learning.
IMPORTANT: This sampler repeats each index n_views times in a row, creating sequences like [0,0,0,0, 1,1,1,1, 2,2,2,2, …] for n_views=4. This means:
- The DataLoader will load the SAME image multiple times consecutively
- Each repeated index goes through the transform pipeline separately
- BATCH SIZE: The batch_size in DataLoader refers to total augmented samples. For example, batch_size=128 with n_views=8 means only 16 unique images, each appearing 8 times with different augmentations.
Designed to work with RoundRobinMultiViewTransform which uses a counter to apply different augmentations to each repeated occurrence of the same image.
- Example behavior with n_views=3:
Dataset indices: [0, 1, 2, 3, 4]
Sampler output: [0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4]
- Parameters:
data_source (Dataset) – dataset to sample from
n_views (int) – number of times to repeat each index consecutively, default=1
replacement (bool) – samples are drawn on-demand with replacement if True, default=False
seed (int) – random seed for shuffling
pass_view_idx (bool) – whether to pass the view index to the dataset getitem
Note: For an alternative approach that loads each image once, consider using MultiViewTransform with a standard sampler.
- class stable_pretraining.data.Subset(dataset: Dataset, indices: Sequence[int])[source]#
Bases: Dataset
Subset of a dataset at specified indices.
- Parameters:
dataset (Dataset) – The whole Dataset
indices (sequence) – Indices in the whole set selected for subset
- property column_names#
- class stable_pretraining.data.SupervisedBatchSampler(batch_size: int, n_views: int, targets_or_dataset: Dataset | list, *args, **kwargs)[source]#
Wraps another sampler to yield a mini-batch of indices.
Example
>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- stable_pretraining.data.bulk_download(urls: Iterable[str], dest_folder: str | Path, backend: str = 'filesystem', cache_dir: str = '~/.stable_pretraining/')[source]#
Download multiple files concurrently.
Example
import stable_pretraining

stable_pretraining.data.bulk_download(
    [
        # URLs elided in the rendered docstring
    ],
    "todelete",
)
- stable_pretraining.data.download(url, dest_folder, backend='filesystem', cache_dir='~/.stable_pretraining/', progress_bar=True, _progress_dict=None, _task_id=None)[source]#
Download a file from a URL with progress tracking.
- Parameters:
url – URL to download from
dest_folder – Destination folder for the download
backend – Storage backend type
cache_dir – Cache directory path
progress_bar – Whether to show progress bar
_progress_dict – Internal dictionary for progress tracking
_task_id – Internal task ID for bulk downloads
- Returns:
Path to the downloaded file or None if download failed
- stable_pretraining.data.fold_views(tensor, idx)[source]#
Fold a tensor containing multiple views back into separate views.
- Parameters:
tensor – Tensor containing concatenated views
idx – Sample indices to determine view boundaries
- Returns:
Tuple of tensors, one for each view
- stable_pretraining.data.generate_perlin_noise_2d(shape, res, octaves=1, persistence=0.5, lacunarity=2.0)[source]#
Generate 2D Perlin noise.
- Parameters:
shape – Output shape (height, width)
res – Resolution tuple
octaves – Number of octaves for fractal noise
persistence – Amplitude multiplier for each octave
lacunarity – Frequency multiplier for each octave
- Returns:
2D tensor of Perlin noise
- stable_pretraining.data.perlin_noise_3d(x, y, z)[source]#
Generate 3D Perlin noise at given coordinates.
- Parameters:
x – X coordinate for noise generation
y – Y coordinate for noise generation
z – Z coordinate for noise generation
- Returns:
Perlin noise value at the given coordinates
- stable_pretraining.data.random_split(dataset: stable_pretraining.data.datasets.Dataset, lengths: Sequence[int | float], generator: torch.Generator | None = <torch._C.Generator object>) list[Subset][source]#
Randomly split a dataset into non-overlapping new datasets of given lengths.
If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.
After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.
Optionally fix the generator for reproducible results, e.g.:
Example
>>> # xdoctest: +SKIP
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- Parameters:
dataset (Dataset) – Dataset to be split
lengths (sequence) – lengths or fractions of splits to be produced
generator (Generator) – Generator used for the random permutation.
- stable_pretraining.data.swiss_roll(N, margin=1, sampler_time=Uniform(low: 0.10000000149011612, high: 3.0), sampler_width=Uniform(low: 0.0, high: 1.0))[source]#
Generate Swiss Roll dataset points.
- Parameters:
N – Number of points to generate
margin – Margin parameter for the roll
sampler_time – Distribution for sampling time parameter
sampler_width – Distribution for sampling width parameter
- Returns:
Tensor of shape (N, 3) containing Swiss Roll points