SupervisedBatchSampler

class stable_ssl.data.SupervisedBatchSampler(batch_size: int, n_views: int, targets_or_dataset: Dataset | list, *args, **kwargs)

Bases: Sampler[List[int]]

Batch sampler that yields mini-batches of dataset indices, built from the targets (labels) supplied through targets_or_dataset.

Parameters:
  • batch_size (int) – Size of mini-batch.

  • n_views (int) – Number of views.

  • targets_or_dataset (Dataset or list) – Either the dataset itself, or a list holding one target per sample.
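Because the constructor accepts either a dataset or a plain list of targets, it has to reduce that argument to one target per index before it can group samples by label. The sketch below shows one way to do this. It assumes a `list` already holds the targets and that any other dataset returns `(sample, target)` pairs from `__getitem__`. `resolve_targets` and `indices_by_target` are hypothetical helpers written for illustration, not stable_ssl functions.

```python
from collections import defaultdict
from typing import Any, Dict, List, Sequence


def resolve_targets(targets_or_dataset: Any) -> List[Any]:
    """Hypothetical helper: return one target per dataset index.

    Assumption: a list already contains the targets, while any other
    indexable dataset yields (sample, target) pairs.
    """
    if isinstance(targets_or_dataset, list):
        return list(targets_or_dataset)
    return [targets_or_dataset[i][1] for i in range(len(targets_or_dataset))]


def indices_by_target(targets: Sequence[Any]) -> Dict[Any, List[int]]:
    """Hypothetical helper: map each target value to the indices carrying it."""
    groups: Dict[Any, List[int]] = defaultdict(list)
    for index, target in enumerate(targets):
        groups[target].append(index)
    return dict(groups)


if __name__ == "__main__":
    # A tuple of (sample, target) pairs stands in for a dataset here.
    toy_dataset = (("x0", 0), ("x1", 1), ("x2", 0))
    print(indices_by_target(resolve_targets(toy_dataset)))
```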

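Example

The following doctest uses torch.utils.data.BatchSampler:

>>> from torch.utils.data import BatchSampler, SequentialSampler
>>> list(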
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
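The doctest above covers only the plain `BatchSampler`. The following is a hedged illustration of what a label-aware batch sampler with this signature might do. It assumes each batch holds `batch_size // n_views` distinct targets, with `n_views` indices drawn per target, similar to the class-balanced "P x K" sampling often used in metric learning. `ToySupervisedBatchSampler`, its `seed` argument, its sampling-with-replacement choice and its `__len__` convention are all hypothetical and not taken from stable_ssl. It is a plain iterable of index lists, which `torch.utils.data.DataLoader` accepts through its `batch_sampler` argument.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Optional, Sequence


class ToySupervisedBatchSampler:
    """Hypothetical sketch, NOT stable_ssl's implementation.

    Each batch contains ``batch_size // n_views`` distinct targets and
    ``n_views`` indices per target, drawn with replacement (so an index can
    repeat inside a batch when a class has few samples).
    """

    def __init__(self, batch_size: int, n_views: int, targets: Sequence[int],
                 seed: Optional[int] = None) -> None:
        if batch_size <= 0 or n_views <= 0 or batch_size % n_views != 0:
            raise ValueError("batch_size must be a positive multiple of n_views")
        self.batch_size = batch_size
        self.n_views = n_views
        self.classes_per_batch = batch_size // n_views
        # Group dataset indices by their target value.
        groups: Dict[int, List[int]] = defaultdict(list)
        for index, target in enumerate(targets):
            groups[target].append(index)
        self.groups = dict(groups)
        if len(self.groups) < self.classes_per_batch:
            raise ValueError("not enough distinct targets to fill one batch")
        self.num_samples = len(targets)
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        # Assumed convention: one "epoch" is num_samples // batch_size batches.
        return self.num_samples // self.batch_size

    def __iter__(self) -> Iterator[List[int]]:
        labels = sorted(self.groups)
        for _ in range(len(self)):
            batch: List[int] = []
            # Pick distinct targets, then n_views indices for each of them.
            for label in self.rng.sample(labels, self.classes_per_batch):
                batch.extend(self.rng.choices(self.groups[label], k=self.n_views))
            yield batch


if __name__ == "__main__":
    toy_targets = [0, 0, 1, 1, 2, 2, 3, 3]
    for batch in ToySupervisedBatchSampler(4, 2, toy_targets, seed=0):
        print(batch)
```

Grouping indices by target up front keeps each batch cheap to build, and drawing distinct targets per batch guarantees that, for `n_views >= 2`, every index in a batch has at least one partner with the same label.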