# Source code for stable_worldmodel.utils

"""Utility functions for stable_worldmodel."""

import os
import shlex
import subprocess
import sys
from collections.abc import Iterable
from typing import Any

from loguru import logger as logging


def pretraining(
    script_path: str,
    dataset_name: str,
    output_model_name: str,
    dump_object: bool = True,
    args: str = "",
) -> int:
    """Run a pretraining script as a subprocess with optional command-line arguments.

    This function checks that the specified script exists, builds the command
    line from the provided arguments, and executes it with the current Python
    interpreter in a subprocess.

    Args:
        script_path (str): The path to the pretraining script to be executed.
        dataset_name (str): The name of the dataset to be used in pretraining.
            Passed to the script as ``dataset_name=<value>``.
        output_model_name (str): The name to save the output model. Passed to
            the script as ``output_model_name=<value>``.
        dump_object (bool, optional): Whether to dump the model object after
            training. Passed to the script as ``++dump_object=<value>``.
            Defaults to True.
        args (str, optional): A string of extra command-line arguments to pass
            to the script. It is split with shell-like rules, so quoting is
            honoured. Defaults to an empty string.

    Returns:
        int: The return code of the subprocess. This is always 0, because a
        non-zero return code terminates the interpreter (see ``Raises``).

    Raises:
        ValueError: If the specified script does not exist.
        SystemExit: If the subprocess exits with a non-zero return code; the
            exit status mirrors the subprocess return code.
    """
    if not os.path.isfile(script_path):
        raise ValueError(f"Script {script_path} does not exist.")

    logging.info(f"🏃🏃🏃 Running pretraining script: {script_path} with args: {args} 🏃🏃🏃")

    # Only set PYTHONUNBUFFERED when the caller has not already chosen a value.
    env = os.environ.copy()
    env.setdefault("PYTHONUNBUFFERED", "1")

    # Only the free-form ``args`` string is shell-split. The generated
    # key=value arguments are appended as single argv entries so that names
    # containing whitespace or quote characters reach the script intact.
    cmd = [
        sys.executable,
        script_path,
        *shlex.split(args),
        f"++dump_object={dump_object}",
        f"dataset_name={dataset_name}",
        f"output_model_name={output_model_name}",
    ]

    try:
        subprocess.run(cmd, env=env, check=True)
    except subprocess.CalledProcessError as e:
        # Propagate the child's failure status as this process's exit status.
        sys.exit(e.returncode)

    logging.info("🏁🏁🏁 Pretraining script finished 🏁🏁🏁")
    # check=True guarantees a zero return code once we get here.
    return 0
def flatten_dict(d, parent_key="", sep="."):
    """Collapse a nested dictionary into one level with joined key paths.

    Keys of nested dictionaries are joined to their parent key with ``sep``,
    giving Hydra-like dotted names. If two different paths produce the same
    flattened key, the value encountered later overwrites the earlier one, so
    information can be lost. Nested dictionaries that are empty contribute no
    entries to the result.

    Args:
        d (dict): The nested dictionary to flatten.
        parent_key (str, optional): Prefix prepended to every flattened key.
            An empty (falsy) prefix is omitted together with its separator.
        sep (str, optional): The separator placed between nesting levels.
            Defaults to '.'.

    Returns:
        dict: A new single-level dictionary; the input is not modified.

    Examples:
        >>> info = {"a": {"b": {"c": 42, "d": 43}}, "e": 44}
        >>> flatten_dict(info)
        {'a.b.c': 42, 'a.b.d': 43, 'e': 44}
        >>> flatten_dict({"a": {"b": 2}, "a.b": 3})
        {'a.b': 3}
    """

    def _leaves(node, prefix):
        # Depth-first walk yielding (flattened_key, value) for each non-dict value.
        for name, value in node.items():
            path = f"{prefix}{sep}{name}" if prefix else name
            if isinstance(value, dict):
                yield from _leaves(value, path)
            else:
                yield path, value

    # dict() keeps the first position of a repeated key and its last value,
    # which is what repeated assignment into a dict does as well.
    return dict(_leaves(d, parent_key))
def get_in(mapping: dict, path: Iterable[str]) -> Any:
    """Look up a value in a nested dictionary by following a sequence of keys.

    The keys in ``path`` are applied one after another, each indexing into the
    result of the previous lookup. An empty path returns ``mapping`` itself.

    Args:
        mapping (dict): A nested dictionary.
        path (Iterable[str]): The keys to follow, outermost first. The
            iterable is fully consumed before any lookup is performed.

    Returns:
        Any: The value found at the end of the path.

    Raises:
        KeyError: If any key in the path does not exist in the mapping dict.

    Examples:
        >>> variations = {"a": {"b": {"c": 42}}}
        >>> get_in(variations, ["a", "b", "c"])
        42
    """
    steps = tuple(path)
    node = mapping
    for step in steps:
        node = node[step]
    return node