Source code for stable_pretraining.data.datasets
"""Dataset classes for real data sources.
This module provides dataset wrappers and utilities for working with real data sources
including PyTorch datasets, HuggingFace datasets, and dataset subsets.
"""
import time
from collections.abc import Sequence
import lightning as pl
import torch
from loguru import logger as logging


class Dataset(torch.utils.data.Dataset):
    """Base dataset class with transform support and PyTorch Lightning integration."""

    def __init__(self, transform=None):
        self.transform = transform
        self._trainer = None

    def set_pl_trainer(self, trainer: pl.Trainer):
        self._trainer = trainer

    def process_sample(self, sample):
        if self._trainer is not None:
            if "global_step" in sample:
                raise ValueError(
                    "'global_step' is a reserved key; it is filled in from the trainer"
                )
            if "current_epoch" in sample:
                raise ValueError(
                    "'current_epoch' is a reserved key; it is filled in from the trainer"
                )
            sample["global_step"] = self._trainer.global_step
            sample["current_epoch"] = self._trainer.current_epoch
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
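

# Usage sketch: ``Dataset`` is abstract, so a subclass supplies the storage and
# routes each sample through ``process_sample``, which applies the transform
# and, once a trainer is attached via ``set_pl_trainer``, injects
# ``global_step`` and ``current_epoch``. The class and key names below are
# hypothetical.
def _dataset_usage_sketch():
    class TensorDictDataset(Dataset):
        def __init__(self, data, transform=None):
            super().__init__(transform)
            self.data = data

        def __getitem__(self, idx):
            return self.process_sample({"x": self.data[idx]})

        def __len__(self):
            return len(self.data)

    ds = TensorDictDataset(torch.randn(8, 3))
    return ds[0]  # {"x": tensor of shape (3,)}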


class Subset(Dataset):
    r"""Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): the whole dataset.
        indices (sequence): indices in the whole set selected for the subset.
    """

    dataset: Dataset
    indices: Sequence[int]

    def __init__(self, dataset: Dataset, indices: Sequence[int]) -> None:
        super().__init__()
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: list[int]) -> list:
        # Batched sampling support when the parent dataset supports it;
        # see torch.utils.data._utils.fetch._MapDatasetFetcher.
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__(
                [self.indices[idx] for idx in indices]
            )  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)

    @property
    def column_names(self):
        return self.dataset.column_names
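

# Usage sketch: ``Subset`` is a lazy view over the parent dataset, so the
# parent is never copied and indices are remapped on every access. The helper
# class below is hypothetical.
def _subset_usage_sketch():
    class _ListDataset(Dataset):
        def __init__(self, rows):
            super().__init__()
            self.rows = rows

        def __getitem__(self, idx):
            return self.process_sample({"value": self.rows[idx]})

        def __len__(self):
            return len(self.rows)

    # Keep only the even-indexed rows; ``range`` satisfies ``Sequence[int]``.
    sub = Subset(_ListDataset(list(range(10))), indices=range(0, 10, 2))
    return len(sub), sub[1]  # (5, {"value": 2})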


class FromTorchDataset(Dataset):
    """Wrapper for PyTorch datasets with custom column naming and transforms.

    Args:
        dataset: PyTorch dataset to wrap.
        names: list of names for each element returned by the dataset.
        transform: optional transform to apply to samples.
        add_sample_idx: if True, automatically adds a 'sample_idx' field to each sample.
    """

    def __init__(self, dataset, names, transform=None, add_sample_idx=True):
        super().__init__(transform)
        self.dataset = dataset
        self.names = names
        self.add_sample_idx = add_sample_idx

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        sample = dict(zip(self.names, sample))
        if self.add_sample_idx:
            sample["sample_idx"] = idx
        return self.process_sample(sample)

    def __len__(self):
        return len(self.dataset)

    @property
    def column_names(self):
        columns = list(self.names)
        if self.add_sample_idx and "sample_idx" not in columns:
            columns.append("sample_idx")
        return columns
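

# Usage sketch: many PyTorch datasets (e.g. torchvision's) return positional
# tuples such as ``(image, label)``; ``names`` maps those positions to keys in
# the sample dict. The toy data and column names below are hypothetical.
def _from_torch_usage_sketch():
    pairs = [(torch.randn(3), i % 2) for i in range(4)]
    ds = FromTorchDataset(pairs, names=["image", "label"])
    sample = ds[0]  # {"image": <tensor>, "label": 0, "sample_idx": 0}
    return sample, ds.column_names  # ..., ["image", "label", "sample_idx"]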


class HFDataset(Dataset):
    """Hugging Face dataset wrapper with transform and column manipulation support."""

    def __init__(
        self, *args, transform=None, rename_columns=None, remove_columns=None, **kwargs
    ):
        super().__init__(transform)
        import datasets

        if (
            torch.distributed.is_initialized()
            and torch.distributed.get_world_size() > 1
        ):
            s = int(torch.distributed.get_rank()) * 2
            logging.info(
                f"Sleeping for {s}s to avoid a race condition on the dataset cache,"
                " see https://github.com/huggingface/transformers/issues/15976"
            )
            time.sleep(s)
        if "storage_options" not in kwargs:
            logging.warning(
                "No `storage_options` passed; adding a client timeout to avoid"
                " network timeouts"
            )
            from aiohttp import ClientTimeout

            kwargs["storage_options"] = {
                "client_kwargs": {"timeout": ClientTimeout(total=3600)}
            }
        dataset = datasets.load_dataset(*args, **kwargs)
        dataset = dataset.add_column("sample_idx", list(range(dataset.num_rows)))
        if rename_columns is not None:
            for k, v in rename_columns.items():
                dataset = dataset.rename_column(k, v)
        if remove_columns is not None:
            dataset = dataset.remove_columns(remove_columns)
        self.dataset = dataset

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return self.process_sample(sample)

    def __len__(self):
        return self.dataset.num_rows

    @property
    def column_names(self):
        return self.dataset.column_names
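

# Usage sketch, assuming network access and the ``datasets`` package; the hub
# name, split, and column names are illustrative. ``split`` should be given so
# that ``load_dataset`` returns a single ``datasets.Dataset`` rather than a
# ``DatasetDict``, since ``__init__`` calls ``add_column`` and ``num_rows``.
def _hf_usage_sketch():
    ds = HFDataset(
        "ylecun/mnist",
        split="train",
        rename_columns={"label": "target"},
    )
    return ds[0]  # {"image": ..., "target": ..., "sample_idx": 0, ...}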