swiss_roll

stable_pretraining.data.swiss_roll(N, margin=1, sampler_time=Uniform(low: 0.10000000149011612, high: 3.0), sampler_width=Uniform(low: 0.0, high: 1.0))

Generate Swiss Roll dataset points.

Parameters:
  • N – Number of points to generate

  • margin – Margin parameter for the roll

  • sampler_time – Distribution for sampling time parameter

  • sampler_width – Distribution for sampling width parameter

Returns:
  Tensor of shape (N, 3) containing Swiss Roll points
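
The following is a minimal, dependency-free sketch of how a textbook Swiss roll can be sampled. It is meant to illustrate the roles of the "time" and "width" samples and the (N, 3) layout of the result, not to reproduce this function's implementation. The name `swiss_roll_sketch`, its `time_range`, `width_range`, and `seed` arguments, and the choice of an Archimedean spiral with one turn per unit of time are all assumptions. The `margin` parameter is left out because the documentation above does not say how it enters the computation, and the sketch returns a list of tuples rather than a tensor.

```python
import math
import random


def swiss_roll_sketch(n, time_range=(0.1, 3.0), width_range=(0.0, 1.0), seed=None):
    """Hypothetical stand-in: sample n points on a textbook Swiss roll."""
    rng = random.Random(seed)
    points = []
    for _ in range(n):
        t = rng.uniform(*time_range)   # position along the spiral ("time")
        w = rng.uniform(*width_range)  # position across the sheet ("width")
        angle = 2.0 * math.pi * t      # assumption: one full turn per unit of t
        # Archimedean spiral in the x-z plane (radius equals t), width along y.
        points.append((t * math.cos(angle), w, t * math.sin(angle)))
    return points


points = swiss_roll_sketch(5, seed=0)
print(len(points), len(points[0]))  # number of points, coordinates per point
```

The default ranges mirror the defaults shown in the signature (time in [0.1, 3.0], width in [0.0, 1.0]). With the real function, the two samplers are distribution objects passed as `sampler_time` and `sampler_width`.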