FromTorchDataset

class stable_pretraining.data.FromTorchDataset(dataset, names, transform=None, add_sample_idx=True)

Bases: Dataset

Wrapper for PyTorch datasets with custom column naming and transforms.

Parameters:
  • dataset – PyTorch dataset to wrap

  • names – List of names for each element returned by the dataset

  • transform – Optional transform to apply to samples

  • add_sample_idx – If True, automatically adds a 'sample_idx' field to each sample
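
The parameters above can be illustrated with a small stand-in class. This is a minimal sketch of the behavior the descriptions suggest, not the library's implementation: the class name NamedDatasetSketch, the helper double_label, the dict return type, and the order of operations (naming, then 'sample_idx', then transform) are all assumptions made for illustration. A plain list of tuples stands in for the wrapped PyTorch dataset so the example runs without any dependencies.

```python
class NamedDatasetSketch:
    """Hypothetical stand-in that mimics the documented parameters.

    NOT the stable_pretraining implementation; the dict output and the
    order of operations below are assumptions.
    """

    def __init__(self, dataset, names, transform=None, add_sample_idx=True):
        self.dataset = dataset
        self.names = list(names)
        self.transform = transform
        self.add_sample_idx = add_sample_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Pair each element returned by the wrapped dataset with its name.
        sample = dict(zip(self.names, self.dataset[idx]))
        # Assumption: the index is stored under the key 'sample_idx'.
        if self.add_sample_idx:
            sample["sample_idx"] = idx
        # Assumption: the transform receives and returns the whole sample.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def double_label(sample):
    """Hypothetical transform: returns a copy with the label doubled."""
    out = dict(sample)
    out["label"] = out["label"] * 2
    return out


# A list of (features, label) tuples stands in for a PyTorch dataset.
raw = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]

wrapped = NamedDatasetSketch(raw, names=["image", "label"], transform=double_label)
first = wrapped[0]
second = wrapped[1]

# With add_sample_idx=False, no index field is added.
plain = NamedDatasetSketch(raw, names=["image", "label"], add_sample_idx=False)
no_idx = plain[1]
```

In this sketch each sample is addressed by name rather than by tuple position, which is the apparent purpose of the names parameter. Check the library source for the actual return type and the point at which the transform is applied.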