Source code for stable_ssl.data.sampler

import math
from typing import Iterable, Iterator, List, Union

import numpy as np
import torch
import torch.distributed as dist


def _check_integer(v, name):
    if not isinstance(v, int) or isinstance(v, bool) or v <= 0:
        raise ValueError(
            f"{name} should be a positive integer value, but got {name}={v}"
        )


class RepeatedRandomSampler(torch.utils.data.DistributedSampler):
    r"""Samples elements randomly, repeating each drawn index ``n_views`` times in a row.

    Without replacement, indices are taken from a shuffled permutation of the dataset
    and each rank receives a disjoint, equally sized slice of that permutation.
    Sampling with replacement is not implemented yet.

    Args:
        data_source_or_len (Dataset or int): dataset to sample from, or its length.
        n_views (int): number of consecutive repetitions of each sampled index, default=1.
        replacement (bool): samples are drawn on-demand with replacement if ``True``,
            default=``False``.
        seed (int): seed of the generator used for shuffling, default=0.
    """

    def __init__(
        self,
        data_source_or_len: Union[int, Iterable],
        n_views: int = 1,
        replacement: bool = False,
        seed: int = 0,
    ):
        if isinstance(data_source_or_len, int):
            self._data_source_len = data_source_or_len
        else:
            self._data_source_len = len(data_source_or_len)
        self.replacement = replacement
        self.n_views = n_views
        self.seed = seed
        self.epoch = 0
        if dist.is_available() and dist.is_initialized():
            self.num_replicas = dist.get_world_size()
            self.rank = dist.get_rank()
            if self.rank >= self.num_replicas or self.rank < 0:
                raise ValueError(
                    f"Invalid rank {self.rank}, rank should be in the interval [0, {self.num_replicas - 1}]"
                )
        else:
            self.num_replicas = 1
            self.rank = 0
        if self._data_source_len % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (self._data_source_len - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = self._data_source_len // self.num_replicas  # type: ignore[arg-type]
        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, but got replacement={self.replacement}"
            )

    def __len__(self):
        return self.num_samples * self.n_views

    # @property
    # def num_samples(self) -> int:
    #     return (len(self.data_source) * self.n_views) // self.num_replicas

    def __iter__(self) -> Iterator[int]:
        n = self._data_source_len
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        if self.replacement:
            raise NotImplementedError()
            # Draft replacement-sampling path below is unreachable until implemented.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(
                    high=n, size=(32,), dtype=torch.int64, generator=g
                ).tolist()
            yield from torch.randint(
                high=n,
                size=(self.num_samples % 32,),
                dtype=torch.int64,
                generator=g,
            ).tolist()
        else:
            overall_slice = torch.randperm(n, generator=g)
            rank_slice = overall_slice[
                self.rank * self.num_samples : (self.rank + 1) * self.num_samples
            ]
            yield from rank_slice.repeat_interleave(self.n_views).tolist()
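# Usage sketch (illustrative only, not part of the module): with a map-style dataset
# that applies a random augmentation on every __getitem__ call, passing this sampler
# to a DataLoader yields each index ``n_views`` times in a row, so every batch
# contains ``n_views`` differently augmented views of the same sample. ``my_dataset``
# and ``num_epochs`` below are placeholders.
#
#     sampler = RepeatedRandomSampler(my_dataset, n_views=2, seed=0)
#     loader = torch.utils.data.DataLoader(
#         my_dataset, batch_size=256, sampler=sampler, drop_last=True
#     )
#     for epoch in range(num_epochs):
#         sampler.epoch = epoch  # re-seeds the per-epoch shuffle deterministically
#         for images, labels in loader:
#             ...  # consecutive elements are views of the same underlying sample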
# class RepeatedSampler(torch.utils.data.BatchSampler):
#     def __init__(
#         self,
#         num_samples_or_dataset: Union[torch.utils.data.Dataset, HFDataset, int],
#         batch_size: Optional[int] = 1,
#         n_views: Optional[int] = 1,
#         shuffle: bool = True,
#         replacement: bool = False,
#         sampler=None,
#         drop_last: bool = False,
#     ) -> None:
#         # if hasattr(num_samples_or_dataset, "__len__"):
#         #     num_samples_or_dataset = len(num_samples_or_dataset)
#         if sampler is None and shuffle:
#             sampler = torch.utils.data.RandomSampler(
#                 num_samples_or_dataset, replacement=replacement
#             )
#         elif sampler is None and not shuffle:
#             sampler = torch.utils.data.SequentialSampler(num_samples_or_dataset)
#         super().__init__(sampler=sampler, batch_size=batch_size, drop_last=drop_last)
#         _check_integer(n_views, "n_views")
#         self.batch_size = batch_size
#         self.n_views = n_views
#
#     def __iter__(self) -> Iterator[List[int]]:
#         # Create multiple references to the same iterator
#         sampler_iter = iter(self.sampler)
#         args = [sampler_iter] * (self.batch_size // self.n_views)
#         for batch_droplast in zip(*args):
#             indices = [*batch_droplast]
#             yield [int(i) for i in np.repeat(indices, self.n_views)]
#
#     def __len__(self) -> int:
#         # Can only be called if self.sampler has __len__ implemented
#         # We cannot enforce this condition, so we turn off typechecking for the
#         # implementation below.
#         return len(self.sampler) // (self.batch_size // self.n_views)
class SupervisedBatchSampler(torch.utils.data.Sampler[List[int]]):
    r"""Yields mini-batches of indices grouped by class label.

    Each batch contains ``batch_size // n_views`` classes drawn (with replacement)
    according to the empirical class prior, and ``n_views`` distinct samples of each
    drawn class.

    Args:
        batch_size (int): total number of indices per mini-batch.
        n_views (int): number of same-class samples drawn per selected class.
        targets_or_dataset (Dataset or list): dataset exposing a ``targets``
            attribute, or the sequence of targets itself.
    """

    def __init__(
        self,
        batch_size: int,
        n_views: int,
        targets_or_dataset: Union[torch.utils.data.Dataset, list],
        *args,
        **kwargs,
    ) -> None:
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        if not isinstance(n_views, int) or isinstance(n_views, bool) or n_views <= 0:
            raise ValueError(
                f"n_views should be a positive integer value, but got n_views={n_views}"
            )
        self.batch_size = batch_size
        self.n_views = n_views
        if isinstance(targets_or_dataset, torch.utils.data.Dataset):
            targets = targets_or_dataset.targets
        else:
            targets = targets_or_dataset
        self._length = len(targets)
        self.batches = {}
        unique_targets, counts = np.unique(targets, return_counts=True)
        self.prior = counts / counts.sum()
        for label in unique_targets:
            self.batches[label.item()] = np.flatnonzero(targets == label)

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(len(self)):
            n_parents = self.batch_size // self.n_views
            parents = np.random.choice(
                list(self.batches.keys()), size=n_parents, replace=True, p=self.prior
            )
            indices = []
            for p in parents:
                indices.extend(
                    np.random.choice(self.batches[p], size=self.n_views, replace=False)
                )
            indices = np.asarray(indices).astype(int)
            yield indices

    def __len__(self) -> int:
        # Number of batches drawn per epoch.
        return self._length // self.batch_size // self.n_views
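# Usage sketch (illustrative only): because this is a *batch* sampler, it goes in the
# DataLoader's ``batch_sampler`` argument. Each yielded batch of 128 indices below
# covers 128 // 4 = 32 classes drawn according to the empirical class prior, with
# 4 same-class samples per class. ``my_dataset`` is a placeholder dataset exposing a
# ``targets`` attribute (e.g. a torchvision classification dataset).
#
#     batch_sampler = SupervisedBatchSampler(
#         batch_size=128, n_views=4, targets_or_dataset=my_dataset.targets
#     )
#     loader = torch.utils.data.DataLoader(my_dataset, batch_sampler=batch_sampler)
#     for images, labels in loader:
#         ...  # labels arrive in consecutive groups of 4 identical values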
class RandomBatchSampler(torch.utils.data.Sampler[List[int]]):
    r"""Yields mini-batches of indices drawn from a fresh random permutation.

    A new permutation of the dataset indices is generated at every epoch and sliced
    into consecutive batches; the last incomplete batch is dropped.

    Args:
        batch_size (int): size of each mini-batch.
        length_or_dataset (Dataset or int): dataset to sample from, or its length.
    """

    def __init__(
        self,
        batch_size: int,
        length_or_dataset: Union[torch.utils.data.Dataset, int],
        *args,
        **kwargs,
    ) -> None:
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        self.batch_size = batch_size
        if isinstance(length_or_dataset, torch.utils.data.Dataset):
            length_or_dataset = len(length_or_dataset)
        self._length = length_or_dataset

    def __iter__(self) -> Iterator[List[int]]:
        perm = np.random.permutation(self._length).astype(int)
        for i in range(len(self)):
            yield perm[i * self.batch_size : (i + 1) * self.batch_size]

    def __len__(self) -> int:
        # Number of full batches per epoch; the last incomplete batch is dropped.
        return self._length // self.batch_size
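# Usage sketch (illustrative only): like SupervisedBatchSampler, this is passed as the
# DataLoader's ``batch_sampler``. Each epoch draws a fresh permutation of the dataset
# indices and slices it into fixed-size batches. ``my_dataset`` is a placeholder.
#
#     batch_sampler = RandomBatchSampler(batch_size=256, length_or_dataset=my_dataset)
#     loader = torch.utils.data.DataLoader(my_dataset, batch_sampler=batch_sampler)
#     for images, labels in loader:
#         ...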