# Source code for tensorial.reaxkit.evaluate
from typing import TYPE_CHECKING
import hydra
import omegaconf
import reax
from . import config, keys, utils
if TYPE_CHECKING:
from tensorial import reaxkit
# Module-level logger, built with ``rank_zero_only=True``.
# NOTE(review): presumably this restricts output to the rank-zero process in
# multi-process runs -- confirm against ``utils.RankedLogger``.
_LOGGER = utils.RankedLogger(__name__, rank_zero_only=True)
# Name of the Hydra config file used when this module is executed as a script
# (passed as ``config_name`` to ``hydra.main`` in the ``__main__`` guard).
DEFAULT_EVAL_FILE = "eval.yaml"
@utils.task_wrapper
def evaluate(cfg: omegaconf.DictConfig) -> None:
    """Run the configured evaluation stages for a checkpointed module.

    Instantiates the datamodule, the loggers and the trainer from ``cfg``, loads the module
    with ``config.load_module`` and then runs validation, testing and/or prediction. A stage
    only runs when its section of ``cfg`` is present and truthy; the contents of that section
    are forwarded as keyword arguments to the matching trainer method.

    This function is wrapped in the ``utils.task_wrapper`` decorator, which controls the
    behavior during failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg: DictConfig configuration composed by Hydra.
    """
    # Look the data node up once, through the same key that is used for the log message, so
    # the logged target and the instantiated object always come from the same config node.
    data_cfg = cfg[keys.DATA]
    _LOGGER.info(
        "Instantiating datamodule <%s>",
        data_cfg._target_,  # pylint: disable=protected-access
    )
    datamodule: reax.DataModule = hydra.utils.instantiate(data_cfg)

    _LOGGER.info("Instantiating loggers...")
    logger: list[reax.Logger] = utils.instantiate_loggers(cfg.get(keys.LOGGER))

    trainer_cfg = cfg[keys.TRAINER]
    _LOGGER.info(
        "Instantiating trainer <%s>",
        trainer_cfg._target_,  # pylint: disable=protected-access
    )
    trainer: reax.Trainer = hydra.utils.instantiate(trainer_cfg, logger=logger)

    model: "reaxkit.ReaxModule" = config.load_module(cfg[keys.CONFIG_PATH], cfg[keys.CKPT_PATH])

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    # Hyperparameters are only logged when at least one logger was instantiated.
    if logger:
        _LOGGER.info("Logging hyperparameters")
        utils.log_hyperparameters(object_dict)

    # (config key, name used in the log message, trainer entry point), in execution order.
    stages = (
        (keys.VALIDATION, "validation", trainer.validate),
        (keys.TEST, "testing", trainer.test),
        (keys.PREDICT, "prediction", trainer.predict),
    )
    for stage_key, stage_name, run_stage in stages:
        # A missing or empty section means the stage is skipped entirely.
        stage_kwargs = cfg.get(stage_key)
        if not stage_kwargs:
            continue

        _LOGGER.info("Starting %s", stage_name)
        run_stage(
            model,
            datamodule=datamodule,
            ckpt_path=cfg[keys.CKPT_PATH],
            **stage_kwargs,
        )
# [docs]
def main(cfg: omegaconf.DictConfig) -> None:
    """Entry point used by Hydra to launch an evaluation run.

    Args:
        cfg: The configuration object composed by Hydra.
    """
    # First let the optional extras act on the config
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) ...
    utils.extras(cfg)

    # ... then hand the same config over to the evaluation task itself.
    evaluate(cfg)
if __name__ == "__main__":
    # Build the Hydra decorator first, then wrap ``main`` with it and invoke the result.
    hydra_decorator = hydra.main(
        config_path="../../configs",
        config_name=DEFAULT_EVAL_FILE,
        version_base="1.3",
    )
    runner = hydra_decorator(main)
    runner()