Source code for tensorial.reaxkit.utils.instantiators

import hydra
import omegaconf
import reax

from . import pylogger

# Module-level logger named after this module. ``rank_zero_only=True`` is
# forwarded to ``RankedLogger``; presumably it limits output to the rank-0
# process -- confirm against ``pylogger``.
log = pylogger.RankedLogger(__name__, rank_zero_only=True)


# Public API of this module (a tuple of the exported function names).
__all__ = "instantiate_listeners", "instantiate_loggers"


def instantiate_listeners(
    listeners_cfg: omegaconf.DictConfig,
) -> list[reax.TrainerListener]:
    """Build trainer listeners from a Hydra configuration mapping.

    Every value of ``listeners_cfg`` that is itself a ``DictConfig`` and
    contains a ``_target_`` key is passed to ``hydra.utils.instantiate``;
    all other entries are ignored.

    Args:
        listeners_cfg: A DictConfig whose values are listener configurations.

    Returns:
        The instantiated listeners, in config iteration order. The list is
        empty (and a warning is logged) when ``listeners_cfg`` is falsy.

    Raises:
        TypeError: If ``listeners_cfg`` is truthy but not a ``DictConfig``.
    """
    if not listeners_cfg:
        log.warning("No listener configs found! Skipping..")
        return []

    if not isinstance(listeners_cfg, omegaconf.DictConfig):
        raise TypeError("listeners config must be a DictConfig!")

    instantiated: list[reax.TrainerListener] = []
    for _name, entry in listeners_cfg.items():
        # Only nested configs that name a ``_target_`` can be instantiated.
        if not isinstance(entry, omegaconf.DictConfig) or "_target_" not in entry:
            continue

        target = entry._target_  # pylint: disable=protected-access
        log.info(f"Instantiating listener <{target}>")
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated


def instantiate_loggers(logger_cfg: omegaconf.DictConfig) -> list[reax.Logger]:
    """Build loggers from a Hydra configuration mapping.

    Every value of ``logger_cfg`` that is itself a ``DictConfig`` and
    contains a ``_target_`` key is passed to ``hydra.utils.instantiate``;
    all other entries are ignored.

    Args:
        logger_cfg: A DictConfig whose values are logger configurations.

    Returns:
        The instantiated loggers, in config iteration order. The list is
        empty (and a warning is logged) when ``logger_cfg`` is falsy.

    Raises:
        TypeError: If ``logger_cfg`` is truthy but not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, omegaconf.DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    built: list[reax.Logger] = []
    for _name, entry in logger_cfg.items():
        # Only nested configs that name a ``_target_`` can be instantiated.
        if not isinstance(entry, omegaconf.DictConfig) or "_target_" not in entry:
            continue

        target = entry._target_  # pylint: disable=protected-access
        log.info(f"Instantiating logger <{target}>")
        built.append(hydra.utils.instantiate(entry))

    return built