# Source code for tensorial.reaxkit.utils.logging_utils

from typing import Any

import jax
from lightning_utilities.core import rank_zero
from omegaconf import OmegaConf

from . import pylogger

# Public API of this module: only the hyperparameter-logging helper is exported.
__all__ = ("log_hyperparameters",)

# Module-level logger. ``rank_zero_only=True`` presumably restricts output to the
# rank-zero process in multi-process runs (NOTE(review): confirm against
# ``pylogger.RankedLogger``).
log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero.rank_zero_only
def log_hyperparameters(object_dict: dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally, it saves:
        - Number of model parameters

    The selected config sections (``model``, ``data``, ``trainer``, plus the
    optional ``callbacks``, ``extras``, ``task_name``, ``tags``, ``ckpt_path``
    and ``seed`` entries, which default to ``None`` when absent) are sent to
    every logger in ``trainer.loggers`` via ``logger.log_hyperparams``. If the
    trainer has no logger, a warning is emitted and nothing is logged.

    Args:
        object_dict: A dictionary containing the following objects:
            - `"cfg"`: A DictConfig object containing the main config.
            - `"model"`: The Lightning model.
            - `"trainer"`: The Lightning trainer.

    Raises:
        KeyError: If ``object_dict`` or the converted config is missing one of
            the required entries listed above.
    """
    import numpy as np  # local import: only needed for parameter counting

    trainer = object_dict["trainer"]

    # Bail out before doing any work (e.g. converting the whole config) when
    # there is nowhere to send the hyperparameters.
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]

    # ``np.size`` uses ``.size`` for array leaves and also handles plain Python
    # scalar leaves (which have no ``.size`` attribute), counting them as 1.
    leaves = jax.tree.leaves(model.parameters())
    n_params_total = sum(np.size(leaf) for leaf in leaves)

    hparams = {
        "model": cfg["model"],
        # save number of model parameters
        "model/params/total": n_params_total,
        "data": cfg["data"],
        "trainer": cfg["trainer"],
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)