Source code for tensorial.reaxkit.utils.instantiators

import logging

import hydra
import omegaconf
import reax

# Logger for this module, named after the module's dotted import path.
_LOGGER = logging.getLogger(__name__)


# Names exported by ``from ... import *``.
__all__ = (
    "instantiate_listeners",
    "instantiate_loggers",
)


def instantiate_listeners(
    listeners_cfg: omegaconf.DictConfig,
) -> list[reax.TrainerListener]:
    """Instantiates trainer listeners from a Hydra config node.

    Every value of ``listeners_cfg`` that is a ``DictConfig`` containing a
    ``_target_`` key is passed to ``hydra.utils.instantiate``. Any other value
    (for example ``None``, or a mapping without ``_target_``) is skipped and a
    debug message naming its key is logged.

    Args:
        listeners_cfg: A DictConfig mapping listener names to listener
            configurations. If it is falsy (e.g. ``None`` or empty), a warning
            is logged and nothing is instantiated.

    Returns:
        The instantiated listeners, in the config's iteration order. Empty if
        ``listeners_cfg`` is falsy or contains no instantiable entries.

    Raises:
        TypeError: If ``listeners_cfg`` is non-empty but not a DictConfig.
        Any exception raised by ``hydra.utils.instantiate`` propagates
        unchanged.
    """
    listeners: list[reax.TrainerListener] = []

    if not listeners_cfg:
        _LOGGER.warning("No listener configs found! Skipping...")
        return listeners

    if not isinstance(listeners_cfg, omegaconf.DictConfig):
        raise TypeError("listeners config must be a DictConfig!")

    for name, cb_conf in listeners_cfg.items():
        if not (isinstance(cb_conf, omegaconf.DictConfig) and "_target_" in cb_conf):
            # Not an error (entries may be intentionally blank), but leave a
            # trace so a misspelled `_target_` key is diagnosable.
            _LOGGER.debug("Skipping listener <%s>: no `_target_` to instantiate", name)
            continue

        _LOGGER.info(
            "Instantiating listener <%s>",
            cb_conf._target_,  # pylint: disable=protected-access
        )
        listeners.append(hydra.utils.instantiate(cb_conf))

    return listeners


def instantiate_loggers(logger_cfg: omegaconf.DictConfig) -> list[reax.Logger]:
    """Instantiates loggers from a Hydra config node.

    Every value of ``logger_cfg`` that is a ``DictConfig`` containing a
    ``_target_`` key is passed to ``hydra.utils.instantiate``. Any other value
    (for example ``None``, or a mapping without ``_target_``) is skipped and a
    debug message naming its key is logged.

    Args:
        logger_cfg: A DictConfig mapping logger names to logger
            configurations. If it is falsy (e.g. ``None`` or empty), a warning
            is logged and nothing is instantiated.

    Returns:
        The instantiated loggers, in the config's iteration order. Empty if
        ``logger_cfg`` is falsy or contains no instantiable entries.

    Raises:
        TypeError: If ``logger_cfg`` is non-empty but not a DictConfig.
        Any exception raised by ``hydra.utils.instantiate`` propagates
        unchanged.
    """
    loggers: list[reax.Logger] = []

    if not logger_cfg:
        _LOGGER.warning("No logger configs found! Skipping...")
        return loggers

    if not isinstance(logger_cfg, omegaconf.DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for name, lg_conf in logger_cfg.items():
        if not (isinstance(lg_conf, omegaconf.DictConfig) and "_target_" in lg_conf):
            # Not an error (entries may be intentionally blank), but leave a
            # trace so a misspelled `_target_` key is diagnosable.
            _LOGGER.debug("Skipping logger <%s>: no `_target_` to instantiate", name)
            continue

        _LOGGER.info(
            "Instantiating logger <%s>",
            lg_conf._target_,  # pylint: disable=protected-access
        )
        loggers.append(hydra.utils.instantiate(lg_conf))

    return loggers