stable_worldmodel.wm package

Submodules

stable_worldmodel.wm.dinowm module

class Attention(dim, heads=8, dim_head=64, dropout=0.0, num_patches=1, num_frames=1)

Bases: Module

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

generate_mask_matrix(npatch, nwindow)
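generate_mask_matrix(npatch, nwindow) has no docstring. It belongs to an attention layer used by CausalPredictor, so it most likely builds a frame-level (block-causal) mask: each of the npatch tokens of frame i may attend to all tokens of frames 0 to i and to nothing later. The pure-Python sketch below illustrates that assumed behaviour. It is not the library's code, and block_causal_mask is a hypothetical name.

```python
from typing import List


def block_causal_mask(npatch: int, nwindow: int) -> List[List[int]]:
    """Hypothetical block-causal attention mask (1 = may attend, 0 = masked).

    Tokens are assumed to be laid out frame by frame: npatch tokens per frame,
    nwindow frames, giving a square matrix of side npatch * nwindow.
    """
    size = npatch * nwindow
    mask = [[0] * size for _ in range(size)]
    for q in range(size):
        q_frame = q // npatch  # frame index of the query token
        for k in range(size):
            # A query sees every token in its own frame and in all earlier frames.
            if k // npatch <= q_frame:
                mask[q][k] = 1
    return mask


# Two patches per frame, two frames: tokens of frame 0 cannot see frame 1.
for row in block_causal_mask(npatch=2, nwindow=2):
    print(row)
```

With npatch=1 the same function reduces to the usual token-level lower-triangular causal mask.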
class CausalPredictor(*, num_patches, num_frames, dim, depth, heads, mlp_dim, pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0)

Bases: Module

forward(x)

class DINOWM(encoder, predictor, action_encoder, proprio_encoder, decoder=None, history_size=3, num_pred=1, device='cpu')

Bases: Module

decode(info)
encode(info, pixels_key='pixels', target='embed', proprio_key=None, action_key=None)
get_cost(info_dict: dict, action_candidates: Tensor)
predict(embedding)

Predict the next latent state.

Parameters:

embedding – latent states of shape (B, T, P, d)

Returns:

preds, the predicted latent states, of shape (B, T, P, d)

replace_action_in_embedding(embedding, act)

Replace the action embeddings in the latent state (embedding) with the provided actions (act).
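This page does not say how actions are stored in the latent state. In DINO-WM-style models, action and proprioception embeddings are often concatenated onto every patch feature along the last dimension. Assuming the action occupies the trailing action_dim channels, replacing it amounts to the hypothetical helper below. It works on nested lists of shape (T, P, d) rather than tensors, and it broadcasts one action vector per frame over all patches.

```python
from typing import List, Sequence


def replace_action(
    embedding: Sequence[Sequence[Sequence[float]]],
    act: Sequence[Sequence[float]],
    action_dim: int,
) -> List[List[List[float]]]:
    """Hypothetical: overwrite the trailing action_dim channels of every patch.

    embedding: T frames, each with P patch features of length d.
    act: T action embeddings, each of length action_dim.
    Returns a new nested list; the input is left unchanged.
    """
    if action_dim <= 0:
        raise ValueError("action_dim must be positive")
    if len(act) != len(embedding):
        raise ValueError("need one action embedding per frame")
    out = []
    for frame, a in zip(embedding, act):
        if len(a) != action_dim:
            raise ValueError("each action embedding must have action_dim values")
        # Keep the non-action channels, then append this frame's action.
        out.append([list(patch[:-action_dim]) + list(a) for patch in frame])
    return out


emb = [[[0.1, 0.2, 9.0], [0.3, 0.4, 9.0]]]  # T=1, P=2, d=3
print(replace_action(emb, [[5.0]], action_dim=1))
```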

rollout(info, action_sequence)

Roll out the world model from an initial observation and a sequence of actions.

Parameters:

info – the starting observations (obs_start): n current observations of shape (B, n, C, H, W)
action_sequence – current and predicted actions, of shape (B, n+t, action_dim)

Returns:

z_obs – dict of latent observations, of shape (B, n+t+1, n_patches, D)
z – predicted latent states, of shape (B, n+t+1, n_patches, D)
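The shapes above imply that rollout returns one more state than there are actions (n+t+1 states for n+t actions). The sketch below shows one way that count arises, under stated assumptions. The n observed latents paired with their n actions yield one prediction, and each of the t remaining actions is then paired with the newest state to yield one more. rollout_sketch, the scalar latents, and the step callable are hypothetical stand-ins for the tensors and the predictor. The real method may also truncate the history to history_size frames.

```python
from typing import Callable, List, Sequence, TypeVar

Z = TypeVar("Z")  # latent state type
A = TypeVar("A")  # action type


def rollout_sketch(
    z_init: Sequence[Z],
    actions: Sequence[A],
    step: Callable[[List[Z], List[A]], Z],
) -> List[Z]:
    """Hypothetical autoregressive rollout returning len(actions) + 1 latents.

    z_init holds the n latents of the observed frames; actions holds n + t
    actions. step(latents, actions) predicts the next latent from aligned
    histories of equal length.
    """
    n = len(z_init)
    if len(actions) < n:
        raise ValueError("need at least one action per initial observation")
    z = list(z_init)
    # First prediction: the n observed latents with their n actions.
    z.append(step(z, list(actions[:n])))
    # Then one further prediction per remaining (future) action.
    for i in range(n, len(actions)):
        z.append(step(z, list(actions[: i + 1])))
    return z


# Toy dynamics: the next scalar latent is the last latent plus the last action.
print(rollout_sketch([0], [1, 2, 3], lambda zs, acts: zs[-1] + acts[-1]))  # prints [0, 1, 3, 6]
```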

split_embedding(embedding, action_dim, proprio_dim)
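split_embedding(embedding, action_dim, proprio_dim) is undocumented. Suppose each patch feature is laid out as [visual | proprio | action] along its last dimension. This ordering is an assumption, because the page does not specify it. Under that layout, splitting a feature is plain slicing. split_feature is a hypothetical per-patch version that works on Python lists.

```python
from typing import List, Sequence, Tuple


def split_feature(
    feature: Sequence[float], action_dim: int, proprio_dim: int
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical split of one patch feature laid out as [visual | proprio | action]."""
    visual_dim = len(feature) - proprio_dim - action_dim
    if visual_dim < 0 or action_dim < 0 or proprio_dim < 0:
        raise ValueError("action_dim + proprio_dim must fit inside the feature")
    visual = list(feature[:visual_dim])
    proprio = list(feature[visual_dim:visual_dim + proprio_dim])
    action = list(feature[visual_dim + proprio_dim:])
    return visual, proprio, action


print(split_feature([1, 2, 3, 4, 5, 6], action_dim=2, proprio_dim=1))
```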
class Decoder(in_channel, out_channel, channel, n_res_block, n_res_channel, stride)

Bases: Module

forward(input)

class Embedder(num_frames=1, tubelet_size=1, in_chans=8, emb_dim=10)

Bases: Module

forward(x)

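Embedder(num_frames=1, tubelet_size=1, in_chans=8, emb_dim=10) is undocumented. Its parameters suggest a temporal tubelet embedding, the kind of module one might pass to DINOWM as action_encoder or proprio_encoder. Assume it is a 1-D convolution whose kernel size and stride both equal tubelet_size. Then it maps each non-overlapping group of tubelet_size timesteps (each holding in_chans values) to one emb_dim vector, which is the same as flattening the group and applying a linear map. The hypothetical tubelet_embed below shows this on plain lists. Its weight layout is an illustrative choice, not the module's parameter layout.

```python
from typing import List, Sequence


def tubelet_embed(
    x: Sequence[Sequence[float]],
    weight: Sequence[Sequence[float]],
    tubelet_size: int = 1,
) -> List[List[float]]:
    """Hypothetical tubelet embedding of one sequence.

    x: T timesteps, each with in_chans values.
    weight: emb_dim rows, each with tubelet_size * in_chans values.
    Returns T // tubelet_size vectors of length emb_dim; trailing steps that
    do not fill a whole tubelet are dropped, as with a strided convolution.
    """
    out = []
    for start in range(0, len(x) - tubelet_size + 1, tubelet_size):
        # Flatten tubelet_size consecutive timesteps into one vector.
        chunk = [v for step in x[start:start + tubelet_size] for v in step]
        out.append([sum(w * v for w, v in zip(row, chunk)) for row in weight])
    return out


# Two timesteps with in_chans=2, projected to emb_dim=3 one step at a time.
print(tubelet_embed([[1, 2], [3, 4]], [[1, 0], [0, 1], [1, 1]]))
```

With the default tubelet_size=1 this is simply a per-timestep linear projection from in_chans to emb_dim.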

class Encoder(in_channel, channel, n_res_block, n_res_channel, stride)

Bases: Module

forward(input)

class FeedForward(dim, hidden_dim, dropout=0.0)

Bases: Module

forward(x)

class Quantize(dim, n_embed, decay=0.99, eps=1e-05)

Bases: Module

embed_code(embed_id)
forward(input)

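Quantize(dim, n_embed, decay=0.99, eps=1e-05) is undocumented. Its decay and eps parameters match the exponential-moving-average (EMA) codebook update used by VQ-VAE-style quantizers. In that scheme, each input is snapped to its nearest code, and codes are refreshed from decayed running counts and sums rather than by gradient descent. embed_code(embed_id) would then be a lookup of codebook rows by index. Assuming this class follows that scheme, one update step looks like the pure-Python sketch below. The function names are hypothetical, and the real module operates on tensors.

```python
from typing import List, Sequence

Vector = List[float]


def nearest_code(x: Sequence[float], codebook: Sequence[Sequence[float]]) -> int:
    """Index of the codebook entry closest to x in squared Euclidean distance."""
    dists = [sum((xi - ci) ** 2 for xi, ci in zip(x, code)) for code in codebook]
    return min(range(len(codebook)), key=dists.__getitem__)


def ema_codebook_step(
    codebook: List[Vector],
    cluster_size: List[float],
    embed_sum: List[Vector],
    inputs: Sequence[Sequence[float]],
    decay: float = 0.99,
    eps: float = 1e-5,
) -> List[Vector]:
    """One hypothetical EMA codebook update; returns the refreshed codebook.

    cluster_size[k] and embed_sum[k] are running (decayed) counts and sums of
    the vectors assigned to code k; both are updated in place. Assumes
    sum(cluster_size) > 0 after the update.
    """
    n_embed, dim = len(codebook), len(codebook[0])
    counts = [0.0] * n_embed
    sums = [[0.0] * dim for _ in range(n_embed)]
    for x in inputs:
        k = nearest_code(x, codebook)  # hard assignment to the closest code
        counts[k] += 1.0
        for j in range(dim):
            sums[k][j] += x[j]
    for k in range(n_embed):
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * counts[k]
        for j in range(dim):
            embed_sum[k][j] = decay * embed_sum[k][j] + (1 - decay) * sums[k][j]
    total = sum(cluster_size)
    # Laplace smoothing keeps rarely used codes from dividing by (near) zero.
    smoothed = [
        (cluster_size[k] + eps) / (total + n_embed * eps) * total for k in range(n_embed)
    ]
    return [[embed_sum[k][j] / smoothed[k] for j in range(dim)] for k in range(n_embed)]


codebook = [[0.0, 0.0], [10.0, 10.0]]
cluster_size = [1.0, 1.0]
embed_sum = [[0.0, 0.0], [10.0, 10.0]]
new_codebook = ema_codebook_step(codebook, cluster_size, embed_sum, [[2.0, 2.0], [4.0, 4.0]])
# Code 0 drifts slightly toward (3, 3), the mean of its two inputs; code 1 stays near (10, 10).
print(new_codebook)
```

A small decay makes codes follow the current batch closely. A value near 1, such as the 0.99 default, averages over many batches.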

class ResBlock(in_channel, channel)

Bases: Module

forward(input)

class Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=0.0, num_patches=1, num_frames=1)

Bases: Module

forward(x)

class VQVAE(in_channel=3, channel=128, n_res_block=2, n_res_channel=32, emb_dim=64, n_embed=512, decay=0.99, quantize=True)

Bases: Module

decode(quant_b)
decode_code(code_b)
forward(input)

input: (b, t, num_patches, emb_dim)
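The forward docstring says VQVAE consumes patch embeddings of shape (b, t, num_patches, emb_dim) rather than images. A convolutional decoder needs a spatial layout. A natural assumption, which this page does not confirm, is that the num_patches tokens of each frame are folded back into a square side × side grid in row-major order before decoding. For example, 256 patches would become a 16 × 16 grid. The hypothetical helper below performs that fold for one frame.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def patches_to_grid(patches: Sequence[T]) -> List[List[T]]:
    """Hypothetical fold of one frame's patch tokens into a square grid.

    patches: num_patches items (each an emb_dim feature vector), assumed to be
    in row-major order, with num_patches a perfect square.
    Returns side rows of side items each, where side = sqrt(num_patches).
    """
    side = math.isqrt(len(patches))
    if side * side != len(patches):
        raise ValueError("num_patches must be a perfect square")
    return [list(patches[r * side:(r + 1) * side]) for r in range(side)]


# Four 1-dimensional patch features become a 2 x 2 grid.
print(patches_to_grid([[0.0], [1.0], [2.0], [3.0]]))
```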

all_reduce(tensor, op=<RedOpType.SUM: 0>)
get_world_size()

stable_worldmodel.wm.dreamer module

stable_worldmodel.wm.dummy module

class DummyWorldModel(image_shape, action_dim)

Bases: Module

encode(obs)
get_cost(info_dict: dict, action_candidates: Tensor)
predict(obs, actions, timestep=None)

Predict the embedding s_{t+H} from s_t and an action sequence, i.e. roll out the dynamics model for H steps.

transform(info_dict)
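Both DINOWM and DummyWorldModel expose get_cost(info_dict, action_candidates), whose signature suggests that it scores candidate action sequences for a planner. A common choice for latent-space planners is the squared distance between the latent state predicted at the end of a candidate's rollout and the latent of the goal observation. That choice is assumed here. This page does not state which objective get_cost actually uses. In the sketch below, latent_goal_cost and select_action_sequence are hypothetical names, and predict_final stands in for encoding the current observation and rolling out the model.

```python
from typing import Callable, List, Sequence, Tuple

Latent = Sequence[float]


def latent_goal_cost(final_latents: Sequence[Latent], goal: Latent) -> List[float]:
    """Squared Euclidean distance from each candidate's final latent to the goal."""
    return [sum((z - g) ** 2 for z, g in zip(latent, goal)) for latent in final_latents]


def select_action_sequence(
    action_candidates: Sequence[Sequence[Sequence[float]]],
    predict_final: Callable[[Sequence[Sequence[float]]], Latent],
    goal: Latent,
) -> Tuple[int, List[float]]:
    """Score every candidate action sequence and return (best_index, costs)."""
    finals = [predict_final(actions) for actions in action_candidates]
    costs = latent_goal_cost(finals, goal)
    best = min(range(len(costs)), key=costs.__getitem__)
    return best, costs


# Toy model: a 2-D point starting at the origin moves by each action in turn.
def toy_final(actions):
    return [sum(a[0] for a in actions), sum(a[1] for a in actions)]


candidates = [
    [[1, 0], [1, 0]],   # ends at (2, 0)
    [[0, 1], [1, 0]],   # ends at (1, 1)
    [[-1, 0], [0, 0]],  # ends at (-1, 0)
]
best, costs = select_action_sequence(candidates, toy_final, goal=[1, 1])
print(best, costs)
```

A sampling-based planner would call such a cost function on many sampled candidates per iteration and keep the lowest-cost ones.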

stable_worldmodel.wm.frame module

stable_worldmodel.wm.tdmpc module

Module contents

class DINOWM(encoder, predictor, action_encoder, proprio_encoder, decoder=None, history_size=3, num_pred=1, device='cpu')

Re-exported from stable_worldmodel.wm.dinowm; documented above.

class DummyWorldModel(image_shape, action_dim)

Re-exported from stable_worldmodel.wm.dummy; documented above.