Source code for tensorial.reaxkit.utils.utils
from collections.abc import Callable
import functools
from importlib.util import find_spec
import logging
from typing import Any
import warnings

import omegaconf

from . import rich_utils
# Public API of this module (names exported by `from ... import *`).
# NOTE(review): "extras" is listed here but is not defined in the visible part of this
# module -- confirm it is defined elsewhere in the file; if it is missing, a star-import
# of this module raises AttributeError.
__all__ = "extras", "task_wrapper", "get_metric_value"
# Module-level logger used by `task_wrapper` and `get_metric_value`.
_LOGGER = logging.getLogger(__name__)
[docs]
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    The returned wrapper:

    - logs any exception raised by the task function, together with its traceback, via the
      module logger and then re-raises it unchanged;
    - always logs the output directory read from ``cfg.paths.output_dir``, whether the task
      succeeded or failed;
    - always closes an active ``wandb`` run (when ``wandb`` is installed), even if the task
      function raised.

    Example:

    .. code-block:: python

        @utils.task_wrapper
        def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
            # ...
            return metric_dict, object_dict

    Args:
        task_func: The task function to be wrapped. It is called as ``task_func(cfg=cfg)``
            and must return a ``(metric_dict, object_dict)`` pair.

    Returns:
        The wrapped task function. It carries the wrapped function's name and docstring
        (via :func:`functools.wraps`).
    """

    @functools.wraps(task_func)
    def wrap(cfg: omegaconf.DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        # things to do if exception occurs
        except Exception:
            # log the exception together with its traceback
            _LOGGER.exception("")
            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure.
            # A bare `raise` re-raises the active exception with its original traceback.
            raise
        # things to always do after either success or exception
        finally:
            # log the output dir path
            # NOTE(review): assumes `cfg.paths.output_dir` exists; if it does not, the error
            # raised here supersedes any exception coming from the task function.
            _LOGGER.info("Output dir: %s", cfg.paths.output_dir)

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb  # pylint: disable=import-error

                if wandb.run:
                    _LOGGER.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
[docs]
def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> float | None:
"""Safely retrieves value of the metric logged in reax.Module.
Args:
metric_dict: A dict containing metric values.
metric_name: If provided, the name of the metric to retrieve.
Returns:
If a metric name was provided, the value of the metric.
"""
if not metric_name:
_LOGGER.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise ValueError(
f"Metric value not found! <metric_name={metric_name}>\n"
"Make sure metric name logged in reax.Module is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
_LOGGER.info("Retrieved metric value! <%s=%f>", metric_name, metric_value)
return metric_value