swiss_roll
- stable_pretraining.data.swiss_roll(N, margin=1, sampler_time=Uniform(low: 0.10000000149011612, high: 3.0), sampler_width=Uniform(low: 0.0, high: 1.0))
Generate Swiss Roll dataset points.
- Parameters:
N – Number of points to generate
margin – Margin parameter for the roll
sampler_time – Distribution from which the time parameter is sampled (default: Uniform on [0.1, 3.0])
sampler_width – Distribution from which the width parameter is sampled (default: Uniform on [0.0, 1.0])
- Returns:
Tensor of shape (N, 3) containing Swiss Roll points
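
The following is a minimal sketch of how Swiss Roll points are commonly generated, not the library's implementation. It assumes the widely used parametrization (t cos t, w, t sin t), where t is the sampled time and w the sampled width. The name `swiss_roll_sketch` and its arguments are hypothetical. The sketch uses plain Python instead of tensors to stay dependency-free, and it leaves out `margin` because its exact role is not described above.

```python
import math
import random


def swiss_roll_sketch(n, t_low=0.1, t_high=3.0, w_low=0.0, w_high=1.0, seed=None):
    """Hypothetical stand-in for illustration; NOT stable_pretraining's code.

    Returns a list of n (x, y, z) tuples lying on a rolled-up 2D sheet.
    """
    rng = random.Random(seed)
    points = []
    for _ in range(n):
        t = rng.uniform(t_low, t_high)  # "time": position along the spiral
        w = rng.uniform(w_low, w_high)  # "width": position across the sheet
        # Assumed parametrization: spiral in the x-z plane, width along y.
        points.append((t * math.cos(t), w, t * math.sin(t)))
    return points


points = swiss_roll_sketch(5, seed=0)
```

In this sketch the distance of each point from the y-axis equals its time value t, so the time range controls how far the spiral unrolls, and the width range controls the extent of the sheet along y.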