random_split

stable_ssl.data.random_split(dataset: Dataset, lengths: Sequence[int | float], generator: Generator | None = <torch._C.Generator object>) -> list[Subset]

Randomly split a dataset into non-overlapping new datasets of given lengths.

If lengths is a list of fractions that sum to 1, each split length is computed automatically as floor(frac * len(dataset)).

If the computed lengths add up to less than len(dataset), the leftover items are distributed one at a time, in round-robin fashion across the splits, until none remain.
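The length rule above can be sketched in plain Python. This is a minimal illustration of the description, not the library's code: the helper name fractions_to_lengths is hypothetical, and it assumes the round-robin pass starts at the first split.

```python
import math


def fractions_to_lengths(n_items, fractions):
    """Hypothetical helper: turn fractions into integer split lengths."""
    # Step 1: floor(frac * n_items) for each fraction.
    lengths = [math.floor(n_items * frac) for frac in fractions]
    # Step 2: hand out the leftover items one at a time, round-robin.
    # Assumption: the round-robin pass starts at the first split.
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


# 11 items split 50/50: both floors are 5, one leftover item goes to the first split.
halves = fractions_to_lengths(11, [0.5, 0.5])
# 10 items split 25/25/50: floors are 2, 2, 5, one leftover item goes to the first split.
thirds = fractions_to_lengths(10, [0.25, 0.25, 0.5])
```

Under these assumptions, the resulting lengths always sum to the dataset size, so every item lands in exactly one split.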

Optionally fix the generator for reproducible results, e.g.:

Example

>>> # xdoctest: +SKIP
>>> import torch
>>> from stable_ssl.data import random_split
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

Parameters:
  • dataset (Dataset) – Dataset to be split.

  • lengths (Sequence[int | float]) – Lengths or fractions of the splits to be produced.

  • generator (Generator) – Generator used for the random permutation.
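To illustrate why a fixed generator gives reproducible, non-overlapping splits, here is a small standalone sketch of the general idea: shuffle the indices with a seeded source of randomness, then cut the permutation into consecutive chunks. It deliberately uses Python's random.Random in place of torch.Generator so that it runs without PyTorch. The helper split_indices and its seed parameter are hypothetical and are not part of stable_ssl.

```python
import random
from itertools import accumulate


def split_indices(n_items, lengths, seed):
    """Hypothetical helper: seeded permutation of range(n_items), cut into chunks."""
    # Assumption for this sketch: integer lengths that cover the whole dataset.
    if sum(lengths) != n_items:
        raise ValueError("lengths must sum to n_items")
    rng = random.Random(seed)  # stand-in for a seeded torch.Generator
    perm = list(range(n_items))
    rng.shuffle(perm)
    # Consecutive slices of one permutation cannot share an index.
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [perm[s:e] for s, e in zip(starts, ends)]


first = split_indices(10, [3, 7], seed=42)
second = split_indices(10, [3, 7], seed=42)
```

Because both calls use the same seed, first and second contain identical index lists. The two chunks within each result are disjoint and together cover every index from 0 to 9.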