Source code for stable_ssl.data.utils

import itertools
import math
import multiprocessing
import os
import time
import warnings
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Iterable, Optional, Union, cast
from urllib.parse import urlparse

import lightning as pl
import numpy as np
import rich.progress
import torch
import torch.distributions as dist
from filelock import FileLock
from loguru import logger as logging
from requests_cache import CachedSession
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

# No 'default_generator' in torch/__init__.pyi
from torch import Generator, default_generator, randperm
from tqdm import tqdm


def bulk_download(
    urls: Iterable[str],
    dest_folder: Union[str, Path],
    backend: str = "filesystem",
    cache_dir: str = "~/.stable_ssl/",
):
    """Download multiple files concurrently.

    Example:
        import stable_ssl

        stable_ssl.data.bulk_download(
            [
                "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
                "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz",
            ],
            "todelete",
        )

    Args:
        urls (Iterable[str]): List of URLs to download.
        dest_folder (Union[str, Path]): Destination folder for downloads.
        backend (str, optional): Storage backend type. Defaults to "filesystem".
        cache_dir (str, optional): Cache directory path. Defaults to
            "~/.stable_ssl/".
    """
    urls = list(urls)  # materialize, in case an arbitrary iterable is passed
    num_workers = len(urls)
    filenames = [os.path.basename(urlparse(url).path) for url in urls]
    with rich.progress.Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        refresh_per_second=5,
    ) as progress:
        futures = []
        with multiprocessing.Manager() as manager:
            # Shared dictionary the worker processes use to report progress.
            _progress = manager.dict()
            with ProcessPoolExecutor(max_workers=num_workers) as executor:
                for i in range(num_workers):
                    task_id = filenames[i]
                    # Submit the download, passing the shared dict and task ID.
                    future = executor.submit(
                        download,
                        urls[i],
                        dest_folder,
                        backend,
                        cache_dir,
                        False,
                        _progress,
                        task_id,
                    )
                    futures.append(future)
                # Create one Rich task per download and keep the bars in sync
                # with the shared dictionary until every future has finished.
                rich_tasks = {}
                while not all(future.done() for future in futures):
                    for task_id in list(_progress.keys()):
                        if task_id in rich_tasks:
                            progress.update(
                                rich_tasks[task_id],
                                completed=_progress[task_id]["progress"],
                            )
                        else:
                            rich_tasks[task_id] = progress.add_task(
                                f"[green]{task_id}",
                                total=_progress[task_id]["total"],
                                visible=True,
                            )
                    time.sleep(0.01)

def download(
    url,
    dest_folder,
    backend="filesystem",
    cache_dir="~/.stable_ssl/",
    progress_bar=True,
    _progress_dict=None,
    _task_id=None,
):
    """Download a single file with caching, file locking, and optional progress reporting.

    Args:
        url (str): URL of the file to download.
        dest_folder (Union[str, Path]): Destination folder for the download.
        backend (str, optional): Storage backend type. Defaults to "filesystem".
        cache_dir (str, optional): Cache directory path. Defaults to "~/.stable_ssl/".
        progress_bar (bool, optional): Whether to show a tqdm progress bar.
        _progress_dict (dict, optional): Shared dict used by ``bulk_download``
            to track progress across processes.
        _task_id (str, optional): Key identifying this download in ``_progress_dict``.
    """
    try:
        filename = os.path.basename(urlparse(url).path)
        # Ensure the destination folder exists.
        dest_folder = Path(dest_folder)
        dest_folder.mkdir(exist_ok=True, parents=True)
        local_filename = dest_folder / filename
        lock_filename = dest_folder / f"{filename}.lock"
        # Use a file lock to prevent concurrent downloads of the same file.
        with FileLock(lock_filename):
            session = CachedSession(cache_dir, backend=backend)
            logging.info(f"Downloading: {url}")
            # Read the total size from the headers before streaming the body.
            response = session.head(url)
            total_size = int(response.headers.get("content-length", 0))
            logging.info(f"Total size: {total_size}")
            response = session.get(url, stream=True)
            # Raise an error for bad responses before writing anything to disk.
            response.raise_for_status()
            downloaded_size = 0
            # Stream the file to the destination folder in 8 KiB chunks.
            with (
                open(local_filename, "wb") as f,
                tqdm(
                    desc=local_filename.name,
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    disable=not progress_bar,
                ) as bar,
            ):
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    downloaded_size += len(chunk)
                    bar.update(len(chunk))
                    if _progress_dict is not None:
                        _progress_dict[_task_id] = {
                            "progress": downloaded_size,
                            "total": total_size,
                        }
            if downloaded_size == total_size:
                logging.info("Download complete and successful!")
            else:
                logging.error("Download incomplete or corrupted.")
            return local_filename
    except Exception as e:
        logging.error(f"Error downloading {url}: {e}")
        raise

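# Usage sketch (illustrative, not part of the module): downloading a single
# archive to a local folder. The URL is taken from the bulk_download example
# above; the destination folder name is an arbitrary choice.
#
#     path = download(
#         "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
#         dest_folder="downloads",
#     )
#     print(path)  # downloads/cifar-10-python.tar.gz
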
class Dataset(torch.utils.data.Dataset):
    """Base dataset class with transform support and PyTorch Lightning integration."""

    def __init__(self, transform=None):
        self.transform = transform
        self._trainer = None

    def set_pl_trainer(self, trainer: pl.Trainer):
        self._trainer = trainer

    def process_sample(self, sample):
        if self._trainer is not None:
            if "global_step" in sample:
                raise ValueError("Can't use the reserved key 'global_step'.")
            if "current_epoch" in sample:
                raise ValueError("Can't use the reserved key 'current_epoch'.")
            sample["global_step"] = self._trainer.global_step
            sample["current_epoch"] = self._trainer.current_epoch
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

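# Subclassing sketch (illustrative, not part of the module): a minimal
# map-style dataset that routes every sample through process_sample, so
# transforms and the optional Lightning trainer metadata are applied
# uniformly. The tensor contents are arbitrary placeholders.
#
#     class RandomVectors(Dataset):
#         def __init__(self, n=100, dim=8, transform=None):
#             super().__init__(transform)
#             self.data = torch.randn(n, dim)
#
#         def __getitem__(self, idx):
#             return self.process_sample({"x": self.data[idx]})
#
#         def __len__(self):
#             return len(self.data)
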
class GMM(Dataset):
    """Gaussian Mixture Model dataset for synthetic data generation."""

    def __init__(self, num_components=5, num_samples=100, dim=2):
        super().__init__()
        # Random means for each component.
        means = torch.rand(num_components, dim) * 10
        # Covariance matrices for each component; for simplicity, diagonal
        # (isotropic) covariances with a random scale.
        covariances = torch.stack(
            [torch.eye(dim) * torch.rand(1) for _ in range(num_components)]
        )
        # Mixing coefficients (weights) drawn from a Dirichlet prior.
        weights = torch.distributions.Dirichlet(torch.ones(num_components)).sample()
        # Categorical distribution over the mixture components.
        mix = dist.Categorical(weights)
        # One multivariate normal distribution per component.
        components = dist.MultivariateNormal(means, covariance_matrix=covariances)
        # The full Gaussian Mixture Model.
        self.model = dist.MixtureSameFamily(mix, components)
        self.samples = self.model.sample((num_samples,))
        # Log-likelihood of every sample under the model.
        self.log_likelihoods = self.model.log_prob(self.samples)

    def score(self, samples):
        return self.model.log_prob(samples)

    def __getitem__(self, idx):
        sample = dict(
            sample=self.samples[idx], log_likelihood=self.log_likelihoods[idx]
        )
        return self.process_sample(sample)

    def __len__(self):
        # Needed for DataLoader support; the base class raises NotImplementedError.
        return len(self.samples)

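# Usage sketch (illustrative, not part of the module): draw a small synthetic
# dataset and score held-out points under the mixture. Shapes follow from dim=2.
#
#     gmm = GMM(num_components=3, num_samples=10, dim=2)
#     sample = gmm[0]
#     sample["sample"].shape               # torch.Size([2])
#     gmm.score(torch.zeros(5, 2)).shape   # torch.Size([5])
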
class Subset(Dataset):
    r"""Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset.
        indices (sequence): Indices in the whole set selected for subset.
    """

    dataset: Dataset
    indices: Sequence[int]

    def __init__(self, dataset: Dataset, indices: Sequence[int]) -> None:
        super().__init__()
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: list[int]) -> list:
        # Add batched sampling support when the parent dataset supports it.
        # See torch.utils.data._utils.fetch._MapDatasetFetcher.
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)

    @property
    def column_names(self):
        return self.dataset.column_names

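# Usage sketch (illustrative, not part of the module): restrict a dataset to
# its even indices.
#
#     gmm = GMM(num_samples=10)
#     sub = Subset(gmm, indices=list(range(0, 10, 2)))
#     len(sub)   # 5
#     sub[1]     # same underlying sample as gmm[2]
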
class FromTorchDataset(Dataset):
    """Wrapper for PyTorch datasets with custom column naming and transforms."""

    def __init__(self, dataset, names, transform=None):
        super().__init__(transform)
        self.dataset = dataset
        self.names = names

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        sample = {k: v for k, v in zip(self.names, sample)}
        return self.process_sample(sample)

    def __len__(self):
        return len(self.dataset)

    @property
    def column_names(self):
        return self.names

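# Usage sketch (illustrative, not part of the module): wrap a torchvision
# dataset so that its (image, label) tuples become named columns. torchvision
# and the root path are assumptions for illustration.
#
#     from torchvision.datasets import CIFAR10
#     base = CIFAR10(root="data", download=True)
#     ds = FromTorchDataset(base, names=["image", "label"])
#     ds[0]["label"]     # an int class index
#     ds.column_names    # ["image", "label"]
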
class MinariStepsDataset(Dataset):
    """Dataset for Minari reinforcement learning data with step-based access."""

    NAMES = ["observations", "actions", "rewards", "terminations", "truncations"]

    def __init__(self, dataset, num_steps=2, transform=None):
        super().__init__(transform)
        self.num_steps = num_steps
        self.dataset = dataset
        # Offset the per-episode start indices so that each episode only
        # contributes windows of `num_steps` consecutive steps and windows
        # never cross episode boundaries.
        self.bounds = self.dataset.episode_indices
        self.bounds -= np.arange(self.dataset.total_episodes) * (num_steps - 1)
        self._length = (
            self.dataset.total_steps - (num_steps - 1) * self.dataset.total_episodes
        )
        logging.info("Minari Dataset setup")
        logging.info(f"\t- {self.dataset.total_episodes} episodes")
        logging.info(f"\t- {len(self)} steps")

    def nested_step(self, value, idx):
        # Recurse into dict-valued fields (e.g. dict observations), slicing a
        # window of `num_steps` consecutive steps at the leaves.
        if type(value) is dict:
            return {k: self.nested_step(v, idx) for k, v in value.items()}
        return value[idx : idx + self.num_steps]

    def __getitem__(self, idx):
        # Map the flat step index to an (episode, frame) pair.
        ep_idx = np.searchsorted(self.bounds, idx, side="right") - 1
        frame_idx = idx - self.bounds[ep_idx]
        episode = self.dataset[ep_idx]
        sample = {
            name: self.nested_step(getattr(episode, name), frame_idx)
            for name in self.NAMES
        }
        return self.process_sample(sample)

    def __len__(self):
        return self._length

    @property
    def column_names(self):
        return self.NAMES

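# Usage sketch (illustrative, not part of the module): load a Minari dataset
# and index overlapping windows of two consecutive steps. The dataset ID is
# an assumption; any Minari dataset with the standard fields should work.
#
#     import minari
#     raw = minari.load_dataset("D4RL/pen-human-v2")
#     ds = MinariStepsDataset(raw, num_steps=2)
#     sample = ds[0]
#     sample["actions"].shape  # (2, action_dim): num_steps consecutive actions
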
class MinariEpisodeDataset(torch.utils.data.Dataset):
    """Dataset for Minari reinforcement learning data with episode-based access."""

    NAMES = ["observations", "actions", "rewards", "terminations", "truncations"]

    def __init__(self, dataset):
        self.dataset = dataset
        self.bounds = self.dataset.episode_indices
        self._trainer = None
        logging.info("Minari Dataset setup")
        logging.info(f"\t- {self.dataset.total_episodes} episodes")
        logging.info(f"\t- {len(self)} steps")

    def set_pl_trainer(self, trainer: pl.Trainer):
        self._trainer = trainer

    def nested_step(self, value, idx):
        # Recurse into dict-valued fields and index a single step at the leaves.
        if type(value) is dict:
            return {k: self.nested_step(v, idx) for k, v in value.items()}
        return value[idx]

    def __getitem__(self, idx):
        # Map the flat step index to an (episode, frame) pair.
        ep_idx = np.searchsorted(self.bounds, idx, side="right") - 1
        frame_idx = idx - self.bounds[ep_idx]
        episode = self.dataset[ep_idx]
        sample = {
            name: self.nested_step(getattr(episode, name), frame_idx)
            for name in self.NAMES
        }
        if self._trainer is not None:
            if "global_step" in sample:
                raise ValueError("Can't use the reserved key 'global_step'.")
            if "current_epoch" in sample:
                raise ValueError("Can't use the reserved key 'current_epoch'.")
            sample["global_step"] = self._trainer.global_step
            sample["current_epoch"] = self._trainer.current_epoch
        return sample

    def __len__(self):
        return self.dataset.total_steps

    @property
    def column_names(self):
        return self.NAMES

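# Usage sketch (illustrative, not part of the module): the same Minari data,
# indexed one step at a time rather than in windows.
#
#     ds = MinariEpisodeDataset(raw)  # `raw` as in the sketch above
#     step = ds[0]
#     step.keys()  # observations, actions, rewards, terminations, truncations
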
class HFDataset(Dataset):
    """Hugging Face dataset wrapper with transform and column manipulation support."""

    def __init__(
        self, *args, transform=None, rename_columns=None, remove_columns=None, **kwargs
    ):
        super().__init__(transform)
        import datasets

        if (
            torch.distributed.is_initialized()
            and torch.distributed.get_world_size() > 1
        ):
            # Stagger loading across ranks to avoid a race condition on the
            # dataset cache.
            s = int(torch.distributed.get_rank()) * 2
            logging.info(
                f"Sleeping for {s}s to avoid race condition of dataset cache"
                " (see https://github.com/huggingface/transformers/issues/15976)"
            )
            time.sleep(s)
        if "storage_options" not in kwargs:
            logging.warning(
                "You didn't pass a storage option; we are adding one to avoid timeouts."
            )
            from aiohttp import ClientTimeout

            kwargs["storage_options"] = {
                "client_kwargs": {"timeout": ClientTimeout(total=3600)}
            }
        dataset = datasets.load_dataset(*args, **kwargs)
        dataset = dataset.add_column("sample_idx", list(range(dataset.num_rows)))
        if rename_columns is not None:
            for k, v in rename_columns.items():
                dataset = dataset.rename_column(k, v)
        if remove_columns is not None:
            dataset = dataset.remove_columns(remove_columns)
        self.dataset = dataset

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return self.process_sample(sample)

    def __len__(self):
        return self.dataset.num_rows

    @property
    def column_names(self):
        return self.dataset.column_names

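# Usage sketch (illustrative, not part of the module): load a Hugging Face
# dataset split and rename a column; the dataset name and columns are
# assumptions for illustration. Note that a split must be requested, since
# the wrapper indexes rows directly.
#
#     ds = HFDataset("ylecun/mnist", split="train", rename_columns={"label": "y"})
#     ds[0].keys()  # image, y, sample_idx
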
class Categorical(torch.nn.Module):
    """Categorical distribution for sampling discrete values with given probabilities."""

    def __init__(
        self,
        values: Union[list, torch.Tensor],
        probabilities: Union[list, torch.Tensor],
    ):
        super().__init__()
        self.mix = torch.distributions.Categorical(torch.Tensor(probabilities))
        self.values = torch.Tensor(values)

    def __call__(self):
        return self.values[self.mix.sample()]

    def sample(self, *args, **kwargs):
        return self.values[self.mix.sample(*args, **kwargs)]

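# Usage sketch (illustrative, not part of the module): sample the values
# 1, 2, or 4 with probabilities 0.5, 0.25, 0.25.
#
#     cat = Categorical(values=[1.0, 2.0, 4.0], probabilities=[0.5, 0.25, 0.25])
#     cat()                        # a single scalar draw
#     cat.sample((1000,)).mean()   # ≈ 2.0 in expectation
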
class ExponentialMixtureNoiseModel(torch.nn.Module):
    """Exponential mixture noise model for data augmentation or sampling."""

    def __init__(self, rates, prior, upper_bound=torch.inf):
        super().__init__()
        mix = torch.distributions.Categorical(torch.Tensor(prior))
        comp = torch.distributions.Exponential(torch.Tensor(rates))
        self.mm = torch.distributions.MixtureSameFamily(mix, comp)
        self.upper_bound = upper_bound

    def __call__(self):
        return self.mm.sample().clip_(min=0, max=self.upper_bound)

    def sample(self, *args, **kwargs):
        return self.mm.sample(*args, **kwargs).clip_(min=0, max=self.upper_bound)

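# Usage sketch (illustrative, not part of the module): a two-component
# exponential mixture with rates 1 and 10, mixed 70/30, clipped at 5.
#
#     noise = ExponentialMixtureNoiseModel(
#         rates=[1.0, 10.0], prior=[0.7, 0.3], upper_bound=5.0
#     )
#     noise()              # a single clipped draw
#     noise.sample((4,))   # four independent draws
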
class ExponentialNormalNoiseModel(torch.nn.Module):
    """Exponential-normal noise model combining exponential and normal distributions."""

    def __init__(self, rate, mean, std, prior, upper_bound=torch.inf):
        super().__init__()
        self.mix = torch.distributions.Categorical(torch.Tensor(prior))
        self.exp = torch.distributions.Exponential(rate)
        self.gauss = torch.distributions.Normal(mean, std)
        self.upper_bound = upper_bound

    def __call__(self):
        # Pick a component, then draw a single value from it.
        mix = self.mix.sample()
        if mix == 0:
            return self.exp.sample().clip_(min=0, max=self.upper_bound)
        return self.gauss.sample().clip_(min=0, max=self.upper_bound)

    def sample(self, *args, **kwargs):
        # Vectorized version: draw from both components, then select per element.
        mix = self.mix.sample(*args, **kwargs)
        exp = self.exp.sample(*args, **kwargs)
        gauss = self.gauss.sample(*args, **kwargs)
        return torch.where(mix.bool(), gauss, exp).clip_(min=0, max=self.upper_bound)

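# Usage sketch (illustrative, not part of the module): mix an Exponential(1)
# with a Normal(3, 0.5), choosing the exponential with probability 0.8
# (component 0 is the exponential).
#
#     noise = ExponentialNormalNoiseModel(rate=1.0, mean=3.0, std=0.5, prior=[0.8, 0.2])
#     noise.sample((8,))   # eight non-negative draws
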
def fold_views(tensor, idx):
    """Fold a flat batch of stacked views back into one tensor per view.

    Assumes `idx` holds the sample index of each row, with the views of a
    sample stored consecutively and every sample having the same number of views.
    """
    _, counts = torch.unique_consecutive(idx, return_counts=True)
    if not counts.min().eq(counts.max()):
        raise RuntimeError("counts are not the same for all samples!")
    n_views = counts[0].item()
    fold_shape = (tensor.size(0) // n_views, n_views)
    t = tensor.view(*fold_shape, *tensor.shape[1:])
    return t.unbind(dim=1)

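# Usage sketch (illustrative, not part of the module): recover per-view
# tensors from a batch where two augmented views of each sample were stacked
# and `idx` repeats each sample index once per view.
#
#     feats = torch.randn(6, 128)              # 3 samples x 2 views
#     idx = torch.tensor([0, 0, 1, 1, 2, 2])   # consecutive view indices
#     v1, v2 = fold_views(feats, idx)
#     v1.shape  # torch.Size([3, 128])
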
def random_split(
    dataset: Dataset,
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> list[Subset]:
    r"""Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given, the lengths will be
    computed automatically as floor(frac * len(dataset)) for each fraction
    provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths until there are no
    remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split.
        lengths (sequence): Lengths or fractions of splits to be produced.
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: list[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # Add 1 to the lengths in round-robin fashion until the remainder is 0.
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )
    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )
    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]