Source code for stable_worldmodel.utils
"""Utility functions for stable_worldmodel."""
import os
import shlex
import subprocess
import sys
from collections.abc import Iterable
from typing import Any
from loguru import logger as logging
[docs]
def pretraining(
script_path: str,
dataset_name: str,
output_model_name: str,
dump_object: bool = True,
args: str = "",
) -> int:
"""Run a pretraining script as a subprocess with optional command-line arguments.
This function checks if the specified script exists, constructs a command to run it with the provided arguments,
and executes the command in a subprocess.
Args:
script_path (str): The path to the pretraining script to be executed.
dataset_name (str): The name of the dataset to be used in pretraining.
output_model_name (str): The name to save the output model.
dump_object (bool, optional): Whether to dump the model object after training. Defaults to
args (str, optional): A string of command-line arguments to pass to the script. Defaults to an empty string.
Returns:
int: The return code of the subprocess. A return code of 0 indicates success.
Raises:
ValueError: If the specified script does not exist.
SystemExit: If the subprocess exits with a non-zero return code.
"""
if not os.path.isfile(script_path):
raise ValueError(f"Script {script_path} does not exist.")
logging.info(f"🏃🏃🏃 Running pretraining script: {script_path} with args: {args} 🏃🏃🏃")
env = os.environ.copy()
env.setdefault("PYTHONUNBUFFERED", "1")
args = f"{args} ++dump_object={dump_object} dataset_name={dataset_name} output_model_name={output_model_name}"
cmd = [sys.executable, script_path] + shlex.split(args)
try:
subprocess.run(cmd, env=env, check=True)
except subprocess.CalledProcessError as e:
sys.exit(e.returncode)
logging.info("🏁🏁🏁 Pretraining script finished 🏁🏁🏁")
return
[docs]
def flatten_dict(d, parent_key="", sep="."):
"""Flatten a nested dictionary into a single-level dictionary with concatenated keys.
The naming convention for the new keys is similar to Hydra's, using a `.` separator to denote levels of nesting.
Attention is needed when flattening dictionaries with overlapping keys, as this may lead to information loss.
Args:
d (dict): The nested dictionary to flatten.
parent_key (str, optional): The base key to use for the flattened keys.
sep (str, optional): The separator to use between levels of nesting. Defaults to '.'.
Returns:
dict: A flattened version of the input dictionary.
Examples:
>>> info = {"a": {"b": {"c": 42, "d": 43}}, "e": 44}
>>> flatten_dict(info)
{'a.b.c': 42, 'a.b.d': 43, 'e': 44}
>>> flatten_dict({"a": {"b": 2}, "a.b": 3})
{'a.b': 3}
"""
items = {}
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.update(flatten_dict(v, new_key, sep=sep))
else:
items[new_key] = v
return items
[docs]
def get_in(mapping: dict, path: Iterable[str]) -> Any:
"""Retrieve a value from a nested dictionary using a sequence of keys.
Args:
mapping (dict): A nested dictionary.
path (Iterable[str]): An iterable of keys representing the path to the desired value in mapping.
Returns:
Any: The value located at the specified path in the nested dictionary.
Raises:
KeyError: If any key in the path does not exist in the mapping dict.
Examples:
>>> variations = {"a": {"b": {"c": 42}}}
>>> get_in(variations, ["a", "b", "c"])
42
"""
cur = mapping
for key in list(path):
cur = cur[key]
return cur