# Source code for stable_ssl.data
"""Data utilities for stable-ssl."""
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
# Randall Balestriero <randallbalestriero@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Iterable, Union
import torch
from datasets import load_dataset
from typing_extensions import override
class _DatasetSamplerWrapper(torch.utils.data.Dataset):
"""Dataset to create indexes from `Sampler` or `Iterable`."""
def __init__(self, sampler) -> None:
self._sampler = sampler
# defer materializing an iterator until it is necessary
self._sampler_list = None
@override
def __getitem__(self, index: int):
if self._sampler_list is None:
self._sampler_list = list(self._sampler)
return self._sampler_list[index]
def __len__(self) -> int:
return len(self._sampler)
def reset(self) -> None:
"""Reset the sampler list in order to get new sampling."""
self._sampler_list = list(self._sampler)
# [docs]
class DistributedSamplerWrapper(torch.utils.data.DistributedSampler):
    """Wrap a dataloader for DDP.

    Parameters
    ----------
    sampler: iterable
        The original dataset sampler.
    """

    def __init__(self, sampler, *args, **kwargs) -> None:
        # Present the sampler as a map-style dataset so DistributedSampler
        # can shard its indices across replicas.
        super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs)

    def __iter__(self) -> Iterable:
        """Iterate over DDP dataset.

        Returns
        -------
        Iterable: minibatch
        """
        # Re-materialize the wrapped sampler so each epoch draws fresh indices.
        self.dataset.reset()
        shard = super().__iter__()
        return (self.dataset[position] for position in shard)
# [docs]
class MultiViewSampler:
    """Apply a list of transforms to an input and return all outputs."""

    def __init__(self, transforms: list):
        logging.info(f"MultiViewSampler initialized with {len(transforms)} views.")
        self.transforms = transforms

    def __call__(self, x):
        # One view per transform; a single transform yields the bare output
        # rather than a one-element list.
        views = [transform(x) for transform in self.transforms]
        return views[0] if len(self.transforms) == 1 else views
# [docs]
class HuggingFaceDataset(torch.utils.data.Dataset):
    """Load a HuggingFace dataset as a map-style torch dataset.

    Parameters
    ----------
    *args: list
        Positional arguments passed to `datasets.load_dataset`.
    rename_columns: dict
        Mapping from existing column names in the HF dataset to the keys this
        dataset should expose, e.g. `{"image": "x", "label": "y"}` renames
        column "image" to "x" and "label" to "y".
    remove_columns: list
        Names of columns to drop from the HF dataset, e.g. `["metadata"]`.
    transform: dict[str, callable]
        Mapping of sample key to a callable applied to that key's value on
        every `__getitem__` call.
    add_index: bool
        Whether to add a key "index" holding the datum index to each sample.
    **kwargs: dict
        Additional keyword arguments to pass to `datasets.load_dataset`.
    """

    def __init__(
        self,
        *args: list,
        rename_columns: dict = None,
        remove_columns: list = None,
        transform: dict = None,
        add_index: bool = False,
        **kwargs: dict,
    ):
        self.add_index = add_index
        self.transform = transform or {}
        dataset = load_dataset(*args, **kwargs)
        # Drop unwanted columns before renaming so names refer to the
        # original HF schema in both arguments.
        if remove_columns is not None:
            dataset = dataset.remove_columns(remove_columns)
        if rename_columns is not None:
            dataset = dataset.rename_columns(rename_columns)
        self.dataset = dataset

    def __len__(self) -> int:
        """Get the length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: Union[int, torch.Tensor]) -> tuple:
        """Get a sample from the dataset.

        Parameters
        ----------
        idx: int or torch.Tensor
            Index to sample from the dataset.

        Returns
        -------
        dict: (str, data)
            A dict containing the data sample.

        Raises
        ------
        TypeError
            If the underlying dataset does not return a dict for `idx`.
        ValueError
            If `add_index` is set but the sample already has an "index" key.
        """
        # Accept scalar tensors (e.g. from a sampler) as indices.
        if isinstance(idx, torch.Tensor) and idx.dim() == 0:
            idx = idx.item()
        idx = int(idx)
        sample = self.dataset[idx]
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not isinstance(sample, dict):
            raise TypeError(
                f"Expected the dataset to return a dict, got {type(sample)}"
            )
        for key, fn in self.transform.items():
            sample[key] = fn(sample[key])
        if self.add_index:
            if "index" in sample:
                raise ValueError("Tried to add index in data but already present")
            sample["index"] = idx
        return sample