RandomBatchSampler

class stable_ssl.data.RandomBatchSampler(batch_size: int, length_or_dataset: Dataset | int, *args, **kwargs)

Bases: Sampler[List[int]]

Yields mini-batches of dataset indices, each as a List[int].

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • length_or_dataset (Dataset or int) – Dataset to draw indices from, or the number of samples as an integer.
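A minimal pure-Python sketch of the idea, assuming each pass shuffles range(n) (where n is the integer given, or len(dataset)) and splits the permutation into full batches, dropping any incomplete final batch. The class name SimpleRandomBatchSampler, the seed parameter, and this sampling scheme are hypothetical; stable_ssl's actual strategy may differ.

```python
import random
from typing import Iterator, List, Sized, Union


class SimpleRandomBatchSampler:
    """Hypothetical sketch: yields random mini-batches of indices.

    Assumes a fresh random permutation of range(n) on every pass, split
    into full batches; an incomplete final batch is dropped.
    """

    def __init__(
        self,
        batch_size: int,
        length_or_dataset: Union[Sized, int],
        seed: int = 0,  # hypothetical parameter, for reproducibility
    ) -> None:
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.batch_size = batch_size
        # Accept either an explicit length or any object that defines __len__.
        if isinstance(length_or_dataset, int):
            self.length = length_or_dataset
        else:
            self.length = len(length_or_dataset)
        self._rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        indices = list(range(self.length))
        self._rng.shuffle(indices)  # new permutation on every pass
        for start in range(0, len(self) * self.batch_size, self.batch_size):
            yield indices[start : start + self.batch_size]

    def __len__(self) -> int:
        # Number of full batches per pass.
        return self.length // self.batch_size


# Usage: 10 samples in batches of 3 gives 3 full batches; one index is left out.
sampler = SimpleRandomBatchSampler(batch_size=3, length_or_dataset=10)
for batch in sampler:
    print(batch)
```

Like torch.utils.data.BatchSampler, an iterable of index lists of this shape can be passed to torch.utils.data.DataLoader through its batch_sampler argument.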

Example

>>> from torch.utils.data import BatchSampler, SequentialSampler
>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
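The batching rule in the example can be sketched in plain Python. batch_indices is a hypothetical helper that reproduces the documented drop_last behaviour; it is not the BatchSampler source.

```python
from typing import Iterable, Iterator, List


def batch_indices(
    indices: Iterable[int], batch_size: int, drop_last: bool
) -> Iterator[List[int]]:
    """Group indices into lists of batch_size (hypothetical helper)."""
    batch: List[int] = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Leftover indices form a short final batch unless drop_last is set.
    if batch and not drop_last:
        yield batch


print(list(batch_indices(range(10), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch_indices(range(10), batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```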