random_split
- stable_ssl.data.random_split(dataset: Dataset, lengths: Sequence[int | float], generator: Generator | None = <torch._C.Generator object>) -> list[Subset]
Randomly split a dataset into non-overlapping new datasets of given lengths.
If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.
If the computed lengths do not add up to len(dataset), the leftover items are handed out one at a time, in round-robin fashion, across the splits until none remain.
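The length computation described above can be sketched in plain Python. This is a minimal illustration of the floor-then-round-robin rule, not the library's actual implementation; the helper name `split_lengths` is hypothetical and does not exist in stable_ssl.

```python
import math
from typing import Sequence


def split_lengths(fractions: Sequence[float], total: int) -> list[int]:
    """Hypothetical helper: turn fractions summing to 1 into integer lengths."""
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    # Step 1: floor(frac * total) for each fraction.
    lengths = [math.floor(frac * total) for frac in fractions]
    # Step 2: hand out the leftover items one at a time, round-robin.
    remainder = total - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(split_lengths([0.3, 0.3, 0.4], 30))    # [9, 9, 12]
print(split_lengths([0.5, 0.25, 0.25], 10))  # [6, 2, 2]
```

In the second call the floors are 5, 2 and 2, which cover only 9 of the 10 items, so the single leftover item goes to the first split.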
Optionally fix the generator for reproducible results, e.g.:
Example
>>> # xdoctest: +SKIP
>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
- Parameters:
dataset (Dataset) – Dataset to be split.
lengths (sequence) – Lengths or fractions of the splits to be produced.
generator (Generator) – Generator used for the random permutation.