# Source code for tensorial.reaxkit.utils.rich_utils

from collections.abc import Mapping, Sequence
import logging
from pathlib import Path
from typing import Final

from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core import rank_zero
import omegaconf
import rich
from rich.prompt import Prompt
import rich.syntax
import rich.tree

# Module-level logger, named after this module's import path.
_LOGGER: logging.Logger = logging.getLogger(__name__)

# Rich style name ("dim"); presumably consumed by tree-rendering helpers
# elsewhere in this module (not used in the code visible here) — confirm.
TREE_STYLE: Final[str] = "dim"










def _parse_tags(raw: str) -> list[str]:
    """Split a comma-separated string into a list of non-empty, stripped tags.

    Args:
        raw: Text such as ``"dev, baseline"``.

    Returns:
        The tags in input order with surrounding whitespace removed. Entries
        that are empty or whitespace-only (e.g. from ``"a,,b"`` or
        ``"a, ,b"``) are dropped.
    """
    # Strip *before* filtering so whitespace-only entries are discarded too;
    # filtering on the raw piece let " " through and produced "" tags.
    stripped = (part.strip() for part in raw.split(","))
    return [tag for tag in stripped if tag]


@rank_zero.rank_zero_only
def enforce_tags(cfg: omegaconf.DictConfig, save_to_file: bool = False) -> None:
    """Prompt the user for tags on the command line if the config has none.

    Wrapped in ``rank_zero_only`` so it only runs on the rank-zero process.

    Args:
        cfg: A DictConfig composed by Hydra. If ``cfg.tags`` is missing or
            empty, it is set in place to the tags the user enters.
        save_to_file: Whether to write ``cfg.tags`` to
            ``<cfg.paths.output_dir>/tags.log``. Default is ``False``.

    Raises:
        ValueError: If no tags are configured and the Hydra job config
            contains an ``id`` entry (treated here as a multirun, where
            interactive prompting is not possible).
    """
    if not cfg.get("tags"):
        # The presence of ``id`` in the Hydra job config is used as the
        # multirun signal; prompting for input there is refused.
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        _LOGGER.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = _parse_tags(raw_tags)

        # ``open_dict`` allows adding the ``tags`` key even if ``cfg`` is in
        # OmegaConf struct mode.
        with omegaconf.open_dict(cfg):
            cfg.tags = tags

        _LOGGER.info("Tags: %s", cfg.tags)

    if save_to_file:
        tags_path = Path(cfg.paths.output_dir, "tags.log")
        with open(tags_path, "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)