stable_worldmodel.wm package
Submodules
stable_worldmodel.wm.dinowm module
- class Attention(dim, heads=8, dim_head=64, dropout=0.0, num_patches=1, num_frames=1)
- Bases: Module
 - forward(x)
   - Define the computation performed at every call. Should be overridden by all subclasses.
   - Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
 - generate_mask_matrix(npatch, nwindow)
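The arguments of generate_mask_matrix(npatch, nwindow) suggest an attention mask over nwindow frames of npatch patches each. A minimal NumPy sketch, assuming the mask is block lower-triangular (every patch of frame i may attend to all patches of frames 0..i, 1 = attend, 0 = masked); the real implementation may use another dtype or add leading broadcast dimensions, and `block_causal_mask` is a hypothetical name:

```python
import numpy as np


def block_causal_mask(npatch: int, nwindow: int) -> np.ndarray:
    """Frame-level causal mask of shape (nwindow * npatch, nwindow * npatch).

    Token k belongs to frame k // npatch; a query token may attend to a key
    token only if the key's frame index is <= the query's frame index.
    """
    frame_of_token = np.arange(nwindow * npatch) // npatch
    # Broadcast query frames (rows) against key frames (columns).
    mask = frame_of_token[:, None] >= frame_of_token[None, :]
    return mask.astype(np.float32)


mask = block_causal_mask(npatch=2, nwindow=3)
print(mask.astype(int))
```

Patches inside the same frame see each other fully, and no patch sees a later frame, which is what lets the predictor be trained on all frames of a window at once.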
 
- class CausalPredictor(*, num_patches, num_frames, dim, depth, heads, mlp_dim, pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0)
- Bases: Module
 - forward(x)
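CausalPredictor is parameterized by num_patches and num_frames, and DINOWM.predict maps (B, T, P, d) to (B, T, P, d). Assuming the predictor treats the T·P patch embeddings as one token sequence under a frame-level causal mask (an assumption, not confirmed by this page), the shape bookkeeping can be sketched as below. `masked_mean_predictor` is a hypothetical stand-in for the transformer that simply averages the tokens each position may attend to.

```python
import numpy as np


def masked_mean_predictor(tokens: np.ndarray, npatch: int) -> np.ndarray:
    """Hypothetical stand-in for the transformer: each token becomes the mean
    of all tokens it may attend to under a frame-level causal mask."""
    n_tokens = tokens.shape[1]
    frame = np.arange(n_tokens) // npatch
    mask = (frame[:, None] >= frame[None, :]).astype(tokens.dtype)  # (N, N)
    # (N, N) @ (B, N, d) broadcasts over the batch axis -> (B, N, d)
    return (mask @ tokens) / mask.sum(axis=1, keepdims=True)[None]


def predict_sketch(embedding: np.ndarray) -> np.ndarray:
    """Flatten frames and patches into one token axis, predict, then restore."""
    B, T, P, d = embedding.shape
    tokens = embedding.reshape(B, T * P, d)  # token k = frame k // P, patch k % P
    out = masked_mean_predictor(tokens, npatch=P)
    return out.reshape(B, T, P, d)


z = np.random.default_rng(0).normal(size=(2, 3, 4, 8))
preds = predict_sketch(z)
print(preds.shape)  # (2, 3, 4, 8)
```

Because of the mask, changing the last frame of z leaves the predictions for all earlier frames unchanged.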
 
- class DINOWM(encoder, predictor, action_encoder, proprio_encoder, decoder=None, history_size=3, num_pred=1, device='cpu')
- Bases: Module
 - decode(info)
 - encode(info, pixels_key='pixels', target='embed', proprio_key=None, action_key=None)
 - get_cost(info_dict: dict, action_candidates: Tensor)
 - predict(embedding)
   - Predict the next latent state.
   - Parameters: embedding, shape (B, T, P, d).
   - Returns: preds, shape (B, T, P, d).
 
 - replace_action_in_embedding(embedding, act)
   - Replace the action embeddings inside the latent state embedding with the provided actions act.
 - rollout(info, action_sequence)
   - Roll out the world model from an initial observation under a sequence of actions.
   - Params:
     - obs_start: n current observations, shape (B, n, C, H, W)
     - actions: current and predicted actions, shape (B, n+t, action_dim)
   - Returns:
     - z_obs: dict of latent observations, shape (B, n+t+1, n_patches, D)
     - z: predicted latent states, shape (B, n+t+1, n_patches, D)
 - split_embedding(embedding, action_dim, proprio_dim)
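split_embedding(embedding, action_dim, proprio_dim) and replace_action_in_embedding(embedding, act) suggest that each patch embedding carries visual, proprioceptive and action channels side by side. The sketch below assumes the layout [visual | proprio | action] concatenated along the last (feature) axis, with the same proprio and action vector repeated for every patch. That layout and the helper names `split_embedding_sketch` and `replace_action_sketch` are hypothetical, chosen for illustration only.

```python
import numpy as np


def split_embedding_sketch(embedding, action_dim, proprio_dim):
    """Split (B, T, P, D) into visual, proprio and action parts.

    Assumes the last axis is laid out as [visual | proprio | action]
    and that action_dim > 0 (a negative-zero slice would be empty).
    """
    visual = embedding[..., : -(proprio_dim + action_dim)]
    proprio = embedding[..., -(proprio_dim + action_dim) : -action_dim]
    action = embedding[..., -action_dim:]
    return visual, proprio, action


def replace_action_sketch(embedding, act_emb):
    """Return a copy whose action channels are overwritten by act_emb.

    act_emb: (B, T, action_dim), broadcast across the patch axis.
    """
    action_dim = act_emb.shape[-1]
    out = embedding.copy()
    out[..., -action_dim:] = act_emb[:, :, None, :]
    return out


B, T, P, vis, prop, act = 2, 3, 4, 16, 5, 6
z = np.zeros((B, T, P, vis + prop + act))
z2 = replace_action_sketch(z, np.ones((B, T, act)))
v, p, a = split_embedding_sketch(z2, action_dim=act, proprio_dim=prop)
print(v.shape, p.shape, a.shape)  # (2, 3, 4, 16) (2, 3, 4, 5) (2, 3, 4, 6)
```

Keeping the action inside every patch token is what allows a planner to swap in candidate actions without re-encoding the observation.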
 
- class Decoder(in_channel, out_channel, channel, n_res_block, n_res_channel, stride)
- Bases: Module
 - forward(input)
 
- class Embedder(num_frames=1, tubelet_size=1, in_chans=8, emb_dim=10)
- Bases: Module
 - forward(x)
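The Embedder signature suggests it projects sequences of in_chans-dimensional vectors to emb_dim, grouping tubelet_size consecutive steps. Assuming a Conv1d-style projection with kernel size and stride both equal to tubelet_size (an assumption; the actual layer is not shown here), the operation is equivalent, up to how the weights are arranged, to the NumPy sketch below. `tubelet_embed` is a hypothetical name.

```python
import numpy as np


def tubelet_embed(x, weight, bias, tubelet_size=1):
    """Project (B, T, in_chans) -> (B, T // tubelet_size, emb_dim).

    Each group of tubelet_size consecutive steps is flattened and projected.
    weight: (tubelet_size * in_chans, emb_dim), bias: (emb_dim,)
    """
    B, T, C = x.shape
    n = T // tubelet_size  # trailing steps that don't fill a tubelet are dropped
    groups = x[:, : n * tubelet_size].reshape(B, n, tubelet_size * C)
    return groups @ weight + bias


rng = np.random.default_rng(0)
in_chans, emb_dim = 8, 10
x = rng.normal(size=(2, 5, in_chans))  # (B, T, in_chans)
w, b = rng.normal(size=(in_chans, emb_dim)), np.zeros(emb_dim)
out = tubelet_embed(x, w, b, tubelet_size=1)
print(out.shape)  # (2, 5, 10)
```

With the default tubelet_size=1 the projection reduces to an independent linear map per timestep.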
 
- class Encoder(in_channel, channel, n_res_block, n_res_channel, stride)
- Bases: Module
 - forward(input)
 
- class FeedForward(dim, hidden_dim, dropout=0.0)
- Bases: Module
 - forward(x)
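FeedForward(dim, hidden_dim, dropout) is most likely the position-wise MLP of the transformer blocks. A NumPy sketch of that pattern, assuming Linear(dim, hidden_dim) → GELU → Linear(hidden_dim, dim); the activation, any LayerNorm in front, and dropout placement are assumptions, and dropout is left out because it is the identity at inference:

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def feed_forward(x, w1, b1, w2, b2):
    """(..., dim) -> (..., dim) through a hidden layer of width hidden_dim."""
    h = gelu(x @ w1 + b1)  # (..., hidden_dim)
    return h @ w2 + b2     # back to (..., dim)


rng = np.random.default_rng(0)
dim, hidden_dim = 8, 32
w1, b1 = rng.normal(size=(dim, hidden_dim)), np.zeros(hidden_dim)
w2, b2 = rng.normal(size=(hidden_dim, dim)), np.zeros(dim)
tokens = rng.normal(size=(2, 12, dim))  # (B, N, dim)
out = feed_forward(tokens, w1, b1, w2, b2)
print(out.shape)  # (2, 12, 8)
```

The map acts on each token independently, so mixing across patches and frames happens only in Attention.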
 
- class Quantize(dim, n_embed, decay=0.99, eps=1e-05)
- Bases: Module
 - embed_code(embed_id)
 - forward(input)
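Quantize(dim, n_embed, decay=0.99, eps=1e-05) has the signature of an exponential-moving-average (EMA) vector quantizer as used in VQ-VAE-2 style code: each input vector is replaced by its nearest codebook entry, and the codebook is updated from EMA statistics (decay) smoothed by eps. The NumPy sketch below covers only the nearest-neighbour lookup and an embed_code-style index-to-vector mapping; the EMA update and the straight-through gradient are omitted, and the (dim, n_embed) codebook layout is an assumption.

```python
import numpy as np


def embed_code(idx, codebook):
    # Map integer codes back to vectors: (...,) -> (..., dim)
    return codebook.T[idx]


def quantize(x, codebook):
    """Nearest-neighbour lookup.

    x: (..., dim); codebook: (dim, n_embed).
    Returns (quantized, indices); quantized has the same shape as x.
    """
    flat = x.reshape(-1, x.shape[-1])  # (N, dim)
    # Squared L2 distance ||a - e||^2 = ||a||^2 - 2 a.e + ||e||^2
    dist = (
        (flat**2).sum(axis=1, keepdims=True)
        - 2 * flat @ codebook
        + (codebook**2).sum(axis=0, keepdims=True)
    )  # (N, n_embed)
    idx = dist.argmin(axis=1).reshape(x.shape[:-1])
    return embed_code(idx, codebook), idx


# dim=2, n_embed=3: the codes are (0, 0), (1, 1) and (5, 5)
codebook = np.array([[0.0, 1.0, 5.0],
                     [0.0, 1.0, 5.0]])
x = np.array([[0.2, -0.1], [4.0, 4.5], [0.9, 1.2]])
q, idx = quantize(x, codebook)
print(idx)  # [0 2 1]
```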
 
- class ResBlock(in_channel, channel)
- Bases: Module
 - forward(input)
 
- class Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0, num_patches=1, num_frames=1)
- Bases: Module
 - forward(x)
 
- class VQVAE(in_channel=3, channel=128, n_res_block=2, n_res_channel=32, emb_dim=64, n_embed=512, decay=0.99, quantize=True)
- Bases: Module
 - decode(quant_b)
 - decode_code(code_b)
 - forward(input)
   - input: tensor of shape (b, t, num_patches, emb_dim)
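Because VQVAE.forward receives patch embeddings of shape (b, t, num_patches, emb_dim) rather than images, a convolutional decoder needs them laid back out on a 2-D grid. Assuming the patches form a square sqrt(num_patches) × sqrt(num_patches) grid in row-major order (an assumption; the actual reshape is not shown on this page), the conversion can be sketched as follows. `patches_to_grid` is a hypothetical name.

```python
import math

import numpy as np


def patches_to_grid(x):
    """(b, t, num_patches, emb_dim) -> (b * t, emb_dim, h, w), h = w = sqrt(num_patches)."""
    b, t, num_patches, emb_dim = x.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f"num_patches={num_patches} is not a perfect square")
    grid = x.reshape(b * t, side, side, emb_dim)  # row-major patch order
    return grid.transpose(0, 3, 1, 2)             # channels-first, as conv layers expect


z = np.zeros((2, 3, 196, 64))    # e.g. 196 = 14 x 14 patches, emb_dim=64
print(patches_to_grid(z).shape)  # (6, 64, 14, 14)
```

Folding the time axis into the batch axis lets a purely spatial decoder process every frame independently.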
 
- all_reduce(tensor, op=<RedOpType.SUM: 0>)
- get_world_size()
stable_worldmodel.wm.dreamer module
stable_worldmodel.wm.dummy module
- class DummyWorldModel(image_shape, action_dim)
- Bases: Module
 - encode(obs)
 - get_cost(info_dict: dict, action_candidates: Tensor)
 - predict(obs, actions, timestep=None)
   - Predict the embedding of s_{t+H} given s_t and an action sequence, i.e. roll out the dynamics model for H steps.
 - transform(info_dict)
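DummyWorldModel.predict rolls the dynamics out for H steps, and both world models expose get_cost(info_dict, action_candidates) for planning. The sketch below illustrates that pattern under loudly stated assumptions: a hypothetical linear dynamics step(s, a) = s + a @ action_map, a cost equal to the squared distance between the final predicted embedding and a goal embedding, and candidates laid out as (num_candidates, H, action_dim). None of these are confirmed by this page; `rollout` and `get_cost_sketch` are illustrative names.

```python
import numpy as np


def rollout(s0, actions, step):
    """Apply step() H times: s_{k+1} = step(s_k, a_k). Returns s_H."""
    s = s0
    for a in actions:  # actions: (H, action_dim)
        s = step(s, a)
    return s


def get_cost_sketch(s0, goal, action_candidates, step):
    """Cost per candidate: squared L2 distance of s_H to the goal embedding."""
    return np.array([
        np.sum((rollout(s0, seq, step) - goal) ** 2)
        for seq in action_candidates  # (num_candidates, H, action_dim)
    ])


action_map = np.eye(2)  # hypothetical action-to-embedding map


def step(s, a):
    # Hypothetical dynamics: the action shifts the embedding.
    return s + a @ action_map


s0, goal = np.zeros(2), np.array([1.0, 1.0])
candidates = np.array([
    [[0.5, 0.5], [0.5, 0.5]],  # reaches the goal exactly
    [[1.0, 0.0], [1.0, 0.0]],  # overshoots x, misses y
])
costs = get_cost_sketch(s0, goal, candidates, step)
print(costs, costs.argmin())  # [0. 2.] 0
```

A sampling-based planner would call such a cost function on many candidates and keep the lowest-cost sequence.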
stable_worldmodel.wm.frame module
stable_worldmodel.wm.tdmpc module
Module contents
- class DINOWM(encoder, predictor, action_encoder, proprio_encoder, decoder=None, history_size=3, num_pred=1, device='cpu')
- Bases: Module. Exposed at package level; its members are identical to stable_worldmodel.wm.dinowm.DINOWM documented above.