Quick Start

import stable_worldmodel as swm


# Keyword settings forwarded to swm.World alongside the environment id.
env_settings = {
    "num_envs": 5,
    "image_shape": (224, 224),
    "render_mode": "rgb_array",
}
world = swm.World("swm/PushT-v1", **env_settings)

# List the names exposed by the world's single-environment variation space.
variation_names = world.single_variation_space.names()
print("Available variations: ", variation_names)

# Attach a random policy to the world, then record 10 episodes under the
# dataset name "example-pusht" with a fixed seed.
random_policy = swm.policy.RandomPolicy()
world.set_policy(random_policy)
world.record_dataset("example-pusht", episodes=10, seed=2347, options=None)

# Run the training script on the dataset name used for recording above.
# NOTE(review): presumably the result is stored under output_model_name --
# confirm against the swm.pretraining documentation.
train_script = "scripts/train/dinowm.py"
pretraining_options = {
    "dataset_name": "example-pusht",
    "output_model_name": "dummy_pusht",
    "dump_object": True,
}
swm.pretraining(train_script, **pretraining_options)

# Single device string shared by the model and the solver.
device = "cuda"

# Look up the cost model by the same name that was passed as
# output_model_name to swm.pretraining, then move it to the device.
cost_model = swm.policy.AutoCostModel("dummy_pusht")
model = cost_model.to(device)

config = swm.PlanConfig(
    horizon=5,
    receding_horizon=5,
    action_block=5,
)
solver = swm.solver.CEMSolver(
    model,
    num_samples=300,
    var_scale=1.0,
    n_steps=30,
    topk=30,
    device=device,
)
policy = swm.policy.WorldModelPolicy(solver=solver, config=config)

# Replace the world's current policy with the planning policy and evaluate
# it for 5 episodes with a fixed seed.
world.set_policy(policy)
results = world.evaluate(episodes=5, seed=2347)