stable_pretraining.data#

The data module provides tools for dataset handling, transforms, sampling, and data loading for self-supervised learning.

Core Components#

DataModule([train, test, val, predict])

PyTorch Lightning DataModule for handling train/val/test/predict dataloaders.

Collator([G_from])

Custom collate function that optionally builds an affinity (or “graph”) matrix based on a specified field.

Dataset([transform])

Base dataset class with transform support and PyTorch Lightning integration.
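To make the affinity (or "graph") matrix mentioned under Collator more concrete, here is a minimal sketch in plain Python. It is not the library's implementation: the helper name `affinity_from_field`, the dictionary-style samples, and the rule "two samples are connected when their field values are equal" are all assumptions made for illustration.

```python
def affinity_from_field(samples, field):
    """Build an N x N 0/1 matrix from one field of each sample.

    Hypothetical helper (not part of stable_pretraining): entry (i, j)
    is 1 when samples i and j hold the same value in `field`, else 0.
    """
    keys = [sample[field] for sample in samples]
    return [[1 if a == b else 0 for b in keys] for a in keys]


# Toy batch; the field names are illustrative only.
batch = [
    {"image": "a.png", "label": 0},
    {"image": "b.png", "label": 1},
    {"image": "c.png", "label": 0},
]
G = affinity_from_field(batch, "label")
```

Under these assumptions the matrix is symmetric with ones on the diagonal, since every sample shares its field value with itself.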

Real Data Wrappers#

FromTorchDataset(dataset, names[, ...])

Wrapper for PyTorch datasets with custom column naming and transforms.

HFDataset(*args[, transform, ...])

Hugging Face dataset wrapper with transform and column manipulation support.

Subset(dataset, indices)

Subset of a dataset at specified indices.

Synthetic Data Generators#

GMM([num_components, num_samples, dim])

Gaussian Mixture Model dataset for synthetic data generation.

MinariStepsDataset(dataset[, num_steps, ...])

Dataset for Minari reinforcement learning data with step-based access.

MinariEpisodeDataset(dataset)

Dataset for Minari reinforcement learning data with episode-based access.

swiss_roll(N[, margin, sampler_time, ...])

Generate Swiss Roll dataset points.

generate_perlin_noise_2d(shape, res[, ...])

Generate 2D Perlin noise.

perlin_noise_3d(x, y, z)

Generate 3D Perlin noise at given coordinates.
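As a rough illustration of what a Swiss Roll generator such as `swiss_roll` produces, the sketch below uses the common parametrization (the one also used by scikit-learn's `make_swiss_roll`). The function name `swiss_roll_sketch` and its `height` and `seed` parameters are hypothetical; the library's own `margin` and `sampler_time` options are not modelled here.

```python
import math
import random


def swiss_roll_sketch(n, height=21.0, seed=0):
    """Return n points (x, y, z) on a Swiss Roll surface.

    Hypothetical helper: t is drawn uniformly from [1.5*pi, 4.5*pi),
    the roll lies in the x-z plane, and y is a uniform height.
    """
    rng = random.Random(seed)
    points = []
    for _ in range(n):
        t = 1.5 * math.pi * (1.0 + 2.0 * rng.random())
        points.append((t * math.cos(t), height * rng.random(), t * math.sin(t)))
    return points


points = swiss_roll_sketch(100)
```

Each point's distance from the y-axis equals its parameter t, which is what gives the surface its spiral cross-section.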

Noise Models#

Categorical(values, probabilities)

Categorical distribution for sampling discrete values with given probabilities.

ExponentialMixtureNoiseModel(rates, prior[, ...])

Exponential mixture noise model for data augmentation or sampling.

ExponentialNormalNoiseModel(rate, mean, std, ...)

Exponential-normal noise model combining exponential and normal distributions.
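The idea behind a categorical distribution over `values` with given `probabilities` can be sketched with the standard library alone. The class name `CategoricalSketch`, its `sample` method, and the `seed` argument are assumptions for illustration and may not match the interface of the library's `Categorical`.

```python
import random


class CategoricalSketch:
    """Hypothetical stand-in for a categorical distribution."""

    def __init__(self, values, probabilities, seed=None):
        if len(values) != len(probabilities):
            raise ValueError("values and probabilities must have the same length")
        if abs(sum(probabilities) - 1.0) > 1e-8:
            raise ValueError("probabilities must sum to 1")
        self.values = list(values)
        self.probabilities = list(probabilities)
        self._rng = random.Random(seed)

    def sample(self, n=1):
        # random.choices draws with replacement, weighted by probability.
        return self._rng.choices(self.values, weights=self.probabilities, k=n)


dist = CategoricalSketch(["low", "high", "never"], [0.5, 0.5, 0.0], seed=0)
draws = dist.sample(1000)
```

A value with probability zero, such as `"never"` here, is never drawn.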

Samplers#

RepeatedRandomSampler(data_source_or_len[, ...])

Samples elements randomly.

SupervisedBatchSampler(batch_size, n_views, ...)

Wraps another sampler to yield a mini-batch of indices.

RandomBatchSampler(batch_size, ...)

Wraps another sampler to yield a mini-batch of indices.
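Both batch samplers above are described as wrapping another sampler to yield mini-batches of indices. The generic pattern, as in `torch.utils.data.BatchSampler`, can be sketched as follows. The function name `batch_indices` is hypothetical, and class-specific behaviour (for example how `n_views` affects batching in `SupervisedBatchSampler`) is not covered.

```python
def batch_indices(sampler, batch_size, drop_last=False):
    """Group the indices produced by `sampler` into lists of `batch_size`.

    Hypothetical helper: the final, shorter batch is kept unless
    `drop_last` is True.
    """
    batch = []
    for index in sampler:
        batch.append(index)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch


# Any iterable of indices can play the role of the wrapped sampler.
batches = list(batch_indices(range(7), batch_size=3))
full_only = list(batch_indices(range(7), batch_size=3, drop_last=True))
```

Here `batches` ends with a partial batch of one index, while `full_only` contains only complete batches.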

Utility Functions#

fold_views(tensor, idx)

Fold a tensor containing multiple views back into separate views.

random_split(dataset, lengths[, generator])

Randomly split a dataset into non-overlapping new datasets of given lengths.

download(url, dest_folder[, backend, ...])

Download a file from a URL with progress tracking.

bulk_download(urls, dest_folder[, backend, ...])

Download multiple files concurrently.
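The non-overlapping split that `random_split` describes can be illustrated with a short standard-library sketch. The name `random_split_sketch` and the integer `seed` argument are assumptions; the library function takes a `generator` instead, and its return type may differ from the plain lists used here.

```python
import random


def random_split_sketch(dataset, lengths, seed=None):
    """Split `dataset` into disjoint random parts of the given lengths.

    Hypothetical helper: shuffles the index range once, then slices it,
    so no element can appear in more than one part.
    """
    if sum(lengths) != len(dataset):
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in lengths:
        chosen = indices[start:start + length]
        splits.append([dataset[i] for i in chosen])
        start += length
    return splits


data = list(range(10))
train, val = random_split_sketch(data, [8, 2], seed=0)
```

Because every part is a slice of one shuffled index list, the parts together cover the dataset exactly once.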

Modules#

transforms

dataset_stats

Dataset statistics for normalization.

synthetic_data

Synthetic and simulated data generators.