Source code for tensorial.reaxkit.train

import pathlib

import hydra
from hydra.core import hydra_config
import omegaconf
import reax.utils

from . import config, from_data, keys, utils
from .utils import pylogger

_LOGGER = pylogger.RankedLogger(__name__, rank_zero_only=True)

DEFAULT_TRAIN_FILE = "train.yaml"


[docs] def train(cfg: omegaconf.DictConfig | dict): if isinstance(cfg, dict): cfg = omegaconf.DictConfig(cfg) try: output_dir = pathlib.Path(hydra_config.HydraConfig.get().runtime.output_dir) except ValueError: output_dir = None # set seed for random number generators in JAX, numpy and python.random if cfg.get("seed"): reax.seed_everything(cfg.seed, workers=True) _LOGGER.info("Instantiating listeners...") listeners: list[reax.TrainerListener] = utils.instantiate_listeners(cfg.get("listeners")) _LOGGER.info("Instantiating loggers...") logger: list[reax.Logger] = utils.instantiate_loggers(cfg.get("logger")) _LOGGER.info( "Instantiating trainer <%s>", cfg[keys.TRAINER]._target_ # pylint: disable=protected-access ) trainer: reax.Trainer = hydra.utils.instantiate( cfg[keys.TRAINER], listeners=listeners, logger=logger, default_root_dir=output_dir, ) _LOGGER.info( "Instantiating datamodule <%s>", cfg.data._target_ # pylint: disable=protected-access ) datamodule: reax.DataModule = hydra.utils.instantiate(cfg.data, _convert_="object") if cfg.get(keys.FROM_DATA): from_data_stage = from_data.FromData( # pylint: disable=no-member cfg[keys.FROM_DATA], trainer.engine, rngs=trainer.rngs, datamodule=datamodule ) stage = trainer.run(from_data_stage) print( "Calculated from data (these can be used in your config files using " "${from_data.<name>}:", ) utils.rich_utils.print_tree(stage.calculated, keys.FROM_DATA) # Save the configuration file here, this way things like inputs used to setup the model # will be baked into the input if output_dir is not None: with open(output_dir / config.DEFAULT_CONFIG_FILE, "w", encoding="utf-8") as file: file.write(omegaconf.OmegaConf.to_yaml(cfg, resolve=True)) _LOGGER.info( "Instantiating model <%s>", cfg[keys.MODEL]._target_, # pylint: disable=protected-access ) model: reax.Module = hydra.utils.instantiate(cfg[keys.MODEL], _convert_="object") object_dict = { "cfg": cfg, "datamodule": datamodule, "model": model, "listeners": listeners, "logger": logger, "trainer": trainer, } if logger: _LOGGER.info("Logging hyperparameters!") utils.log_hyperparameters(object_dict) # Fit the potential if cfg.get("train"): _LOGGER.info("Starting training!") trainer.fit( model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"), **cfg.get("train") ) train_metrics = trainer.listener_metrics if cfg.get("test"): _LOGGER.info("Starting testing!") ckpt_path = trainer.checkpoint_listener.best_model_path if ckpt_path == "": _LOGGER.warning("Best ckpt not found! Using current weights for testing...") ckpt_path = None trainer.test( model, datamodule=datamodule, ckpt_path=ckpt_path, ) _LOGGER.info("Best ckpt path: %s", ckpt_path) test_metrics = trainer.listener_metrics # merge train and test metrics metric_dict = {**train_metrics, **test_metrics} return metric_dict, object_dict
[docs] def main(cfg: omegaconf.DictConfig) -> float | None: """Main entry point for training. Args: cfg: DictConfig configuration composed by Hydra. Returns: Optional[float] with optimized metric value. """ # apply extra utilities # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) utils.extras(cfg) # train the model metric_dict, _ = train(cfg) # safely retrieve metric value for hydra-based hyperparameter optimization metric_value = utils.get_metric_value( metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") ) # return optimized metric return metric_value
if __name__ == "__main__": runner = hydra.main( version_base="1.3", config_path="../../configs", config_name=DEFAULT_TRAIN_FILE, )(main) runner()