Source code for tensorial.reaxkit.utils.utils

from collections.abc import Callable
import functools
from importlib.util import find_spec
import logging
from typing import Any
import warnings

import omegaconf

from . import rich_utils

# Names exported by ``from ... import *``.
__all__ = (
    "extras",
    "task_wrapper",
    "get_metric_value",
)

# Logger for this module, named after the module's dotted path.
_LOGGER = logging.getLogger(__name__)


def extras(cfg: omegaconf.DictConfig) -> None:
    """Apply optional utilities before the task starts.

    Each utility is switched on by a flag in ``cfg.extras``:

    - ``ignore_warnings``: silence all python warnings,
    - ``enforce_tags``: make sure tags are set, prompting on the command line when the
      config provides none,
    - ``print_config``: pretty-print the config tree using Rich.

    Args:
        cfg: A DictConfig object containing the config tree.
    """
    # Guard clause: an absent (or empty) `extras` section means there is nothing to do.
    if not cfg.get("extras"):
        _LOGGER.warning("Extras config not found! <cfg.extras=null>")
        return

    extras_cfg = cfg.extras

    if extras_cfg.get("ignore_warnings"):
        _LOGGER.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    if extras_cfg.get("enforce_tags"):
        # Delegated to rich_utils, which asks the user for tags if the config has none.
        _LOGGER.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    if extras_cfg.get("print_config"):
        _LOGGER.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    The wrapped function:

    - logs any exception raised by the task (with traceback) through this module's logger
      and then re-raises it,
    - always logs the output directory (``cfg.paths.output_dir``), on success and failure,
    - always finishes an active wandb run, if wandb is installed, so a failing run does not
      break a multirun.

    Adjust depending on your needs.

    Example:
        .. code-block:: python

            @utils.task_wrapper
            def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
                # ...
                return metric_dict, object_dict

    Args:
        task_func: The task function to be wrapped. It is called as ``task_func(cfg=cfg)``
            and must return a ``(metric_dict, object_dict)`` pair.

    Returns:
        The wrapped task function. It preserves ``task_func``'s name and docstring and
        returns the task's ``(metric_dict, object_dict)`` pair.
    """

    # functools.wraps keeps __name__/__doc__ and sets __wrapped__, so introspection and
    # logging still see the original task rather than a generic ``wrap``.
    @functools.wraps(task_func)
    def wrap(cfg: omegaconf.DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception:
            # Record the exception and traceback via the logging setup (e.g. a `.log` file
            # when a file handler is configured).
            _LOGGER.exception("")
            # Some hyperparameter combinations might be invalid or cause out-of-memory errors,
            # so when using hparam search plugins like Optuna, you might want to disable
            # re-raising here to avoid multirun failure. (If you do, the `return` below needs
            # fallback values, since the task's results are then unbound.)
            raise  # bare raise keeps the original traceback intact
        finally:
            # Runs on both success and failure: show where outputs were written.
            _LOGGER.info("Output dir: %s", cfg.paths.output_dir)

            # Always close the wandb run (even if an exception occurred) so multirun won't fail.
            if find_spec("wandb"):  # only if wandb is installed
                import wandb  # pylint: disable=import-error

                if wandb.run:
                    _LOGGER.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
[docs] def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> float | None: """Safely retrieves value of the metric logged in reax.Module. Args: metric_dict: A dict containing metric values. metric_name: If provided, the name of the metric to retrieve. Returns: If a metric name was provided, the value of the metric. """ if not metric_name: _LOGGER.info("Metric name is None! Skipping metric value retrieval...") return None if metric_name not in metric_dict: raise ValueError( f"Metric value not found! <metric_name={metric_name}>\n" "Make sure metric name logged in reax.Module is correct!\n" "Make sure `optimized_metric` name in `hparams_search` config is correct!" ) metric_value = metric_dict[metric_name].item() _LOGGER.info("Retrieved metric value! <%s=%f>", metric_name, metric_value) return metric_value