stable_worldmodel.wm package
Submodules
stable_worldmodel.wm.dinowm module
- class Attention(dim, heads=8, dim_head=64, dropout=0.0, num_patches=1, num_frames=1)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- generate_mask_matrix(npatch, nwindow)
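`Attention` receives `num_patches` and `num_frames`, and `generate_mask_matrix(npatch, nwindow)` is undocumented. One plausible reading, assumed here and not confirmed by the docstrings, is a frame-level block-causal mask: each patch token may attend to every token of its own frame and of earlier frames, but not to later frames. `block_causal_mask` and `masked_attention` are hypothetical helpers written for this sketch.

```python
import torch


def block_causal_mask(npatch: int, nwindow: int) -> torch.Tensor:
    """Hypothetical helper: frame-level causal mask for nwindow frames of npatch tokens.

    Returns a (1, 1, N, N) float tensor with N = npatch * nwindow, where
    1.0 means "query i may attend to key j" and 0.0 means "blocked".
    """
    frame_of_token = torch.arange(nwindow).repeat_interleave(npatch)  # (N,)
    # Query i may see key j only if j's frame is not later than i's frame.
    allowed = frame_of_token[:, None] >= frame_of_token[None, :]  # (N, N) bool
    return allowed.float()[None, None]


def masked_attention(q, k, v, mask):
    """Scaled dot-product attention; q, k, v: (B, heads, N, dim_head)."""
    scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    scores = scores.masked_fill(mask == 0, float("-inf"))  # hide future frames
    return scores.softmax(dim=-1) @ v


mask = block_causal_mask(npatch=2, nwindow=3)
print(mask[0, 0])
```

With `npatch=2, nwindow=3`, the two tokens of frame 0 see only frame 0, while the tokens of frame 2 see all six tokens. Unlike a token-level causal mask, patches within the same frame can attend to each other in both directions.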
- class CausalPredictor(*, num_patches, num_frames, dim, depth, heads, mlp_dim, pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0)
Bases: Module
- forward(x)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
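The arguments `num_patches`, `num_frames`, `dim`, `depth`, `heads`, and `mlp_dim` suggest a ViT-style predictor over `num_frames * num_patches` tokens. The sketch below is a hypothetical stand-in built from `torch.nn.TransformerEncoder`, not the library's implementation: it flattens a (B, T, P, d) latent sequence into tokens, adds a learned positional embedding, and blocks attention to later frames so that the output for frame t depends only on frames up to t. The `pool`, `dim_head`, and dropout options are not modeled.

```python
import torch
from torch import nn


class TinyCausalPredictor(nn.Module):
    """Hypothetical predictor: (B, T, P, d) -> (B, T, P, d), causal across frames."""

    def __init__(self, num_patches, num_frames, dim, depth=2, heads=4, mlp_dim=64):
        super().__init__()
        n_tokens = num_patches * num_frames
        self.pos_embedding = nn.Parameter(torch.randn(1, n_tokens, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=mlp_dim,
            dropout=0.0, batch_first=True,
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        frame = torch.arange(num_frames).repeat_interleave(num_patches)
        # True = blocked: a token may not attend to tokens from a later frame.
        self.register_buffer("blocked", frame[None, :] > frame[:, None])

    def forward(self, x):
        b, t, p, d = x.shape
        tokens = x.reshape(b, t * p, d) + self.pos_embedding[:, : t * p]
        mask = self.blocked[: t * p, : t * p]  # frame-major order, so a prefix slice works
        return self.blocks(tokens, mask=mask).reshape(b, t, p, d)


model = TinyCausalPredictor(num_patches=4, num_frames=3, dim=16)
preds = model(torch.randn(2, 3, 4, 16))  # (B, T, P, d) in, same shape out
```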
- class DINOWM(encoder, predictor, action_encoder, proprio_encoder, decoder=None, history_size=3, num_pred=1, device='cpu')
Bases: Module
- decode(info)
- encode(info, pixels_key='pixels', target='embed', proprio_key=None, action_key=None)
- get_cost(info_dict: dict, action_candidates: Tensor)
- predict(embedding)
Predict the next latent state.
- Parameters:
embedding – latent states, shape (B, T, P, d)
- Returns:
preds – predicted latent states, shape (B, T, P, d)
- replace_action_in_embedding(embedding, act)
Replace the action embeddings in the latent state (embedding) with the provided actions (act).
- rollout(info, action_sequence)
Roll out the world model given an initial observation and a sequence of actions.
- Parameters:
info – initial observations (obs_start): the n current observations, shape (B, n, C, H, W)
action_sequence – current and predicted actions, shape (B, n+t, action_dim)
- Returns:
z_obs – dict of latent observations, shape (B, n+t+1, n_patches, D)
z – predicted latent states, shape (B, n+t+1, n_patches, D)
- split_embedding(embedding, action_dim, proprio_dim)
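How `split_embedding`, `replace_action_in_embedding`, and `rollout` fit together is not spelled out above. The sketch assumes one possible layout: every patch token carries its visual features followed by proprioceptive and action channels, concatenated along the last dimension, and `history_size` bounds the context window given to the predictor. All functions are hypothetical re-creations, and `predict` is a stand-in callable with the documented (B, T, P, d) to (B, T, P, d) contract. Note how the shapes reproduce the documented n+t+1 output length.

```python
import torch


def split_embedding(z, action_dim, proprio_dim):
    """Split the last dim into (visual, proprio, action); assumes action_dim > 0."""
    tail = proprio_dim + action_dim
    visual = z[..., :-tail]
    proprio = z[..., -tail:-action_dim]
    action = z[..., -action_dim:]
    return visual, proprio, action


def replace_action_in_embedding(z, act_emb):
    """Overwrite the trailing action channels of every patch.

    z: (B, T, P, D); act_emb: (B, T, action_dim). Returns a new tensor.
    """
    action_dim = act_emb.shape[-1]
    z = z.clone()
    z[..., -action_dim:] = act_emb.unsqueeze(2).expand(-1, -1, z.shape[2], -1)
    return z


def rollout(z_obs, act_embs, predict, history_size=3):
    """Autoregressive rollout in latent space.

    z_obs:    (B, n, P, D) encoded observations
    act_embs: (B, n + t, action_dim) embedded actions
    predict:  callable mapping (B, T, P, D) -> (B, T, P, D)
    returns:  (B, n + t + 1, P, D)
    """
    n = z_obs.shape[1]
    z = replace_action_in_embedding(z_obs, act_embs[:, :n])
    for step in range(n, act_embs.shape[1]):
        # Predict from the most recent history_size frames and keep the newest frame.
        z_next = predict(z[:, -history_size:])[:, -1:]
        z_next = replace_action_in_embedding(z_next, act_embs[:, step : step + 1])
        z = torch.cat([z, z_next], dim=1)
    # One final prediction after the last action has been applied.
    z_last = predict(z[:, -history_size:])[:, -1:]
    return torch.cat([z, z_last], dim=1)


B, n, t, P = 2, 3, 4, 5
d_vis, d_prop, d_act = 8, 2, 3
z_obs = torch.randn(B, n, P, d_vis + d_prop + d_act)
act_embs = torch.randn(B, n + t, d_act)
z = rollout(z_obs, act_embs, predict=lambda h: h)  # identity stand-in predictor
```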
- class Decoder(in_channel, out_channel, channel, n_res_block, n_res_channel, stride)
Bases: Module
- forward(input)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
- class Embedder(num_frames=1, tubelet_size=1, in_chans=8, emb_dim=10)
Bases: Module
- forward(x)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
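The defaults of `Embedder` (`in_chans=8`, `emb_dim=10`, `tubelet_size=1`) suggest a small projection of low-dimensional per-timestep inputs such as actions or proprioception, rather than an image patch embedding. Under that assumption, one minimal version is a 1-D convolution over time whose kernel and stride both equal `tubelet_size`, so each group of `tubelet_size` consecutive steps becomes one embedding vector. `TubeletEmbedder` is a hypothetical name, and `num_frames` is not modeled.

```python
import torch
from torch import nn


class TubeletEmbedder(nn.Module):
    """Hypothetical: (B, T, in_chans) -> (B, T // tubelet_size, emb_dim)."""

    def __init__(self, tubelet_size=1, in_chans=8, emb_dim=10):
        super().__init__()
        # Conv1d expects (B, channels, length): channels = features, length = time.
        self.proj = nn.Conv1d(in_chans, emb_dim, kernel_size=tubelet_size, stride=tubelet_size)

    def forward(self, x):
        x = x.permute(0, 2, 1)     # (B, in_chans, T)
        x = self.proj(x)           # (B, emb_dim, T // tubelet_size)
        return x.permute(0, 2, 1)  # (B, T // tubelet_size, emb_dim)


emb = TubeletEmbedder(tubelet_size=2, in_chans=8, emb_dim=10)
out = emb(torch.randn(4, 6, 8))  # 6 time steps -> 3 tubelets
```

Making the kernel equal to the stride means tubelets never overlap, which mirrors how non-overlapping patches are cut from an image.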
- class Encoder(in_channel, channel, n_res_block, n_res_channel, stride)
Bases: Module
- forward(input)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
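`Encoder` and `Decoder` both take a `stride` argument. One common convolutional VQ-VAE layout realizes a stride of 4 as two stride-2 convolutions in the encoder, mirrored by two stride-2 transposed convolutions in the decoder; whether this package does the same is an assumption. The sketch shows only that spatial arithmetic (residual blocks omitted), and `make_encoder` / `make_decoder` are hypothetical names.

```python
import torch
from torch import nn


def make_encoder(in_channel, channel, stride):
    """Hypothetical encoder: downsample H and W by `stride` (2 or 4)."""
    if stride == 4:
        layers = [
            nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(channel, channel, 3, padding=1),
        ]
    elif stride == 2:
        layers = [
            nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(channel // 2, channel, 3, padding=1),
        ]
    else:
        raise ValueError("stride must be 2 or 4")
    return nn.Sequential(*layers)


def make_decoder(channel, out_channel, stride):
    """Hypothetical decoder: upsample H and W by `stride` (2 or 4)."""
    layers = [nn.Conv2d(channel, channel, 3, padding=1), nn.ReLU()]
    if stride == 4:
        layers += [
            nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(channel // 2, out_channel, 4, stride=2, padding=1),
        ]
    elif stride == 2:
        layers += [nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)]
    else:
        raise ValueError("stride must be 2 or 4")
    return nn.Sequential(*layers)


enc, dec = make_encoder(3, 64, stride=4), make_decoder(64, 3, stride=4)
x = torch.randn(1, 3, 64, 64)
z = enc(x)      # (1, 64, 16, 16)
x_hat = dec(z)  # (1, 3, 64, 64)
```

A 4×4 kernel with stride 2 and padding 1 exactly halves an even spatial size, and the matching transposed convolution exactly doubles it, so encoder and decoder shapes line up.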
- class FeedForward(dim, hidden_dim, dropout=0.0)
Bases: Module
- forward(x)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
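`FeedForward(dim, hidden_dim, dropout)` matches the name and arguments of the position-wise MLP in a ViT block. A minimal sketch, assuming the common pre-norm form (LayerNorm, Linear up to `hidden_dim`, GELU, Linear back to `dim`, with dropout after each Linear); the library's exact layer order is not documented here, and `FeedForwardSketch` is a hypothetical name.

```python
import torch
from torch import nn


class FeedForwardSketch(nn.Module):
    """Hypothetical ViT-style MLP: (..., dim) -> (..., dim), applied to each token independently."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


ff = FeedForwardSketch(dim=32, hidden_dim=128)
y = ff(torch.randn(2, 10, 32))  # token-wise: shape is preserved
```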
- class Quantize(dim, n_embed, decay=0.99, eps=1e-05)
Bases: Module
- embed_code(embed_id)
- forward(input)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
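`Quantize(dim, n_embed, decay, eps)` has the signature of an exponential-moving-average (EMA) vector quantizer: `n_embed` codebook vectors of size `dim`, with `decay` controlling the codebook update and `eps` smoothing the per-code counts. The sketch assumes that design, and that `embed_code` is a plain codebook lookup; it is not the library's code. It shows nearest-neighbour assignment, the straight-through gradient trick, and the EMA update applied in training mode. `EMAQuantizer` is a hypothetical name.

```python
import torch
from torch import nn
import torch.nn.functional as F


class EMAQuantizer(nn.Module):
    """Hypothetical EMA vector quantizer: (..., dim) -> (quantized, commit_loss, codes)."""

    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay, self.eps, self.n_embed = decay, eps, n_embed
        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)  # codebook, one column per code
        # EMA statistics, initialized as if each code had been seen once.
        self.register_buffer("cluster_size", torch.ones(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.t())  # code ids -> vectors

    def forward(self, x):
        flat = x.reshape(-1, x.shape[-1])  # (N, dim)
        dist = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )  # (N, n_embed) squared distances
        codes = dist.argmin(dim=1)
        quantized = self.embed_code(codes).view_as(x)

        if self.training:
            with torch.no_grad():
                onehot = F.one_hot(codes, self.n_embed).type(flat.dtype)
                self.cluster_size.mul_(self.decay).add_(onehot.sum(0), alpha=1 - self.decay)
                self.embed_avg.mul_(self.decay).add_(flat.t() @ onehot, alpha=1 - self.decay)
                n = self.cluster_size.sum()
                # Laplace smoothing keeps rarely used codes from dividing by ~0.
                smoothed = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
                self.embed.copy_(self.embed_avg / smoothed.unsqueeze(0))

        commit_loss = (quantized.detach() - x).pow(2).mean()
        # Straight-through estimator: forward uses codebook vectors, backward passes gradients to x.
        quantized = x + (quantized - x).detach()
        return quantized, commit_loss, codes.view(x.shape[:-1])


vq = EMAQuantizer(dim=4, n_embed=8)
q, loss, codes = vq(torch.randn(2, 5, 4))
```

Because the codebook is updated from running averages rather than by gradient descent, only the encoder receives gradients, through the straight-through path and the commitment loss.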
- class ResBlock(in_channel, channel)
Bases: Module
- forward(input)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
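`ResBlock(in_channel, channel)` presumably adds a small convolutional branch back onto its input, which requires the branch to return `in_channel` channels. A sketch assuming the VQ-VAE-style branch ReLU, 3×3 conv to `channel`, ReLU, 1×1 conv back to `in_channel`; the exact layers are an assumption and `ResBlockSketch` is a hypothetical name.

```python
import torch
from torch import nn


class ResBlockSketch(nn.Module):
    """Hypothetical residual block: output has the same shape as the input."""

    def __init__(self, in_channel, channel):
        super().__init__()
        self.branch = nn.Sequential(
            nn.ReLU(),  # not in-place: x itself is still needed for the skip connection
            nn.Conv2d(in_channel, channel, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, kernel_size=1),
        )

    def forward(self, x):
        return x + self.branch(x)  # identity skip connection


block = ResBlockSketch(in_channel=64, channel=32)
y = block(torch.randn(2, 64, 16, 16))
```

If the last convolution is zero-initialized, the block starts out as an exact identity, which is one reason residual stacks train stably.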
- class Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0, num_patches=1, num_frames=1)
Bases: Module
- forward(x)
Define the computation performed at every call (docstring inherited from torch.nn.Module). Call the module instance rather than forward() directly, so that registered hooks run.
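`Transformer(dim, depth, heads, dim_head, mlp_dim, ...)` reads as a stack of `depth` attention and feed-forward layers. A minimal sketch of the common pre-norm residual pattern, x = x + Attn(LN(x)) followed by x = x + MLP(LN(x)), assuming that is the structure used here. It relies on `torch.nn.MultiheadAttention`, whose head size is fixed to `dim // heads`, so a separate `dim_head` is not modeled; `PreNormTransformer` is a hypothetical name.

```python
import torch
from torch import nn


class PreNormTransformer(nn.Module):
    """Hypothetical stack: x = x + Attn(LN(x)); x = x + MLP(LN(x)), repeated depth times."""

    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.ModuleDict({
                "norm1": nn.LayerNorm(dim),
                "attn": nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True),
                "norm2": nn.LayerNorm(dim),
                "mlp": nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim)),
            }))
        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None):
        # attn_mask: (N, N) bool, True = position may NOT be attended to.
        for layer in self.layers:
            h = layer["norm1"](x)
            attn_out, _ = layer["attn"](h, h, h, attn_mask=attn_mask, need_weights=False)
            x = x + attn_out                         # residual around attention
            x = x + layer["mlp"](layer["norm2"](x))  # residual around the MLP
        return self.final_norm(x)


net = PreNormTransformer(dim=32, depth=2, heads=4, mlp_dim=64)
out = net(torch.randn(2, 6, 32))
```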
- class VQVAE(in_channel=3, channel=128, n_res_block=2, n_res_channel=32, emb_dim=64, n_embed=512, decay=0.99, quantize=True)
Bases: Module
- decode(quant_b)
- decode_code(code_b)
- forward(input)
- Parameters:
input – shape (b, t, num_patches, emb_dim)
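`VQVAE.forward` is documented as taking patch tokens of shape (b, t, num_patches, emb_dim) rather than images, so a convolutional model of this kind has to lay the tokens back out on a spatial grid before applying convolutions. A minimal sketch of that reshaping, assuming `num_patches` is a perfect square and patches are stored in row-major order; `tokens_to_grid` and `grid_to_frames` are hypothetical helpers.

```python
import math

import torch


def tokens_to_grid(z):
    """(b, t, num_patches, emb_dim) -> (b * t, emb_dim, h, w), with h = w = sqrt(num_patches)."""
    b, t, p, d = z.shape
    side = math.isqrt(p)
    if side * side != p:
        raise ValueError(f"num_patches={p} is not a perfect square")
    # Row-major patch order: patch index = row * side + col.
    return z.reshape(b * t, side, side, d).permute(0, 3, 1, 2).contiguous()


def grid_to_frames(x, b, t):
    """(b * t, C, H, W) -> (b, t, C, H, W), e.g. for decoded images."""
    return x.reshape(b, t, *x.shape[1:])


z = torch.randn(2, 3, 196, 64)  # 14 x 14 patches per frame
grid = tokens_to_grid(z)        # (6, 64, 14, 14)
```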
- all_reduce(tensor, op=<RedOpType.SUM: 0>)
- get_world_size()
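`all_reduce(tensor, op=<RedOpType.SUM: 0>)` and `get_world_size()` look like thin wrappers over `torch.distributed`; the default `op` appears to be the repr of `torch.distributed.ReduceOp.SUM`. A common reason for such wrappers, assumed here, is to let single-process runs work without an initialized process group; their placement next to the VQ-VAE code suggests, though the docs do not say so, that they aggregate quantizer statistics across processes. The sketch uses hypothetical names.

```python
import torch
import torch.distributed as dist


def get_world_size_safe() -> int:
    """Number of processes, or 1 when torch.distributed is unavailable or not initialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def all_reduce_safe(tensor: torch.Tensor, op=None) -> torch.Tensor:
    """In-place all-reduce across processes (default: sum); a no-op in single-process runs."""
    if get_world_size_safe() > 1:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM if op is None else op)
    return tensor


counts = torch.tensor([1.0, 2.0, 3.0])
all_reduce_safe(counts)  # unchanged when not running under torch.distributed
mean_counts = counts / get_world_size_safe()
```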
stable_worldmodel.wm.dreamer module
stable_worldmodel.wm.dummy module
- class DummyWorldModel(image_shape, action_dim)
Bases: Module
- encode(obs)
- get_cost(info_dict: dict, action_candidates: Tensor)
- predict(obs, actions, timestep=None)
Predict the embedding of s_{t+H} given s_t and an action sequence, i.e. roll out the dynamics model for H steps.
- transform(info_dict)
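`get_cost(info_dict, action_candidates)` appears on both `DINOWM` and `DummyWorldModel`, which suggests a shared planning interface: a planner proposes candidate action sequences and the world model scores each one. The sketch below shows one round of random-shooting planning under stated assumptions: the cost is the squared distance between the latent reached after the last action and a goal latent, and `action_candidates` has shape (B, S, H, action_dim) for S candidates of horizon H. The toy linear dynamics, shapes, and function names are all hypothetical.

```python
import torch


def toy_predict(z0, actions, A_dyn, B_dyn):
    """Hypothetical linear latent dynamics, rolled out for H steps.

    z0: (N, d); actions: (N, H, action_dim). Returns z_H: (N, d),
    using z_{k+1} = z_k @ A_dyn + a_k @ B_dyn.
    """
    z = z0
    for k in range(actions.shape[1]):
        z = z @ A_dyn + actions[:, k] @ B_dyn
    return z


def score_candidates(z0, z_goal, action_candidates, A_dyn, B_dyn):
    """Cost of each candidate plan = squared distance of the final latent to the goal.

    z0, z_goal: (B, d); action_candidates: (B, S, H, action_dim). Returns (B, S).
    """
    b, s, h, a = action_candidates.shape
    # Evaluate all B * S candidates in one batched rollout.
    start = z0.unsqueeze(1).expand(b, s, -1).reshape(b * s, -1)
    goal = z_goal.unsqueeze(1).expand(b, s, -1).reshape(b * s, -1)
    z_final = toy_predict(start, action_candidates.reshape(b * s, h, a), A_dyn, B_dyn)
    return (z_final - goal).pow(2).sum(dim=-1).view(b, s)


torch.manual_seed(0)
d, act_dim = 4, 2
A_dyn, B_dyn = torch.eye(d), torch.randn(act_dim, d)
z0, z_goal = torch.zeros(1, d), torch.randn(1, d)
candidates = torch.randn(1, 64, 5, act_dim)    # 64 random plans, horizon 5
costs = score_candidates(z0, z_goal, candidates, A_dyn, B_dyn)
best_plan = candidates[0, costs[0].argmin()]  # (5, act_dim): lowest-cost plan
```

An iterative planner such as the cross-entropy method would refit its sampling distribution to the lowest-cost candidates and repeat; this sketch stops after a single round.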
stable_worldmodel.wm.frame module
stable_worldmodel.wm.tdmpc module
Module contents
- class DINOWM(encoder, predictor, action_encoder, proprio_encoder, decoder=None, history_size=3, num_pred=1, device='cpu')
Bases: Module
Re-exported from stable_worldmodel.wm.dinowm, so it is also importable as stable_worldmodel.wm.DINOWM. Its methods (decode, encode, get_cost, predict, replace_action_in_embedding, rollout, split_embedding) are documented under stable_worldmodel.wm.dinowm.DINOWM.