Source code for tensorial.reaxkit.from_data

from collections.abc import Mapping
import functools
import logging
from typing import Any

import beartype
from flax import nnx
import hydra
import jaxtyping as jt
import jraph
import omegaconf
import reax
import reax.utils
from typing_extensions import override

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably consumed by other modules - confirm before removing.
DONE_KEY = "done"

# Public API of this module.
__all__ = ("FromData",)


class FromData(reax.stages.Stage):
    """A trainer stage that will populate an OmegaConf dictionary with data statistics calculated
    from metrics.

    Statistics are evaluated over several steps: entries whose configuration still contains
    unresolved interpolations are deferred until the values they refer to have been calculated
    and written back into the configuration.
    """

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        cfg: omegaconf.DictConfig,
        engine: reax.Engine,
        *,
        rngs: nnx.Rngs | None = None,
        dataloader: reax.DataLoader | None = None,
        datamodule: reax.DataModule | None = None,
        dataloader_name: str | None = "train",
        ignore_missing: bool = True,
    ):
        """Populate a hydra configurations dictionary using calculated stats

        Args:
            cfg: the configuration dictionary
            engine: the trainer strategy
            rngs: the random number generator
            dataloader: the dataloader to use
            datamodule: if no dataloader is specified, a data module can be used instead
            dataloader_name: the datamodule dataloader name
            ignore_missing: if `True`, any data that is needed to calculate a metric but is
                missing will be ignored, and that metric will not be calculated
        """
        super().__init__(
            "from_data",
            module=None,
            engine=engine,
            rngs=rngs,
            datamanager=reax.data.create_manager(
                datamodule=datamodule, engine=engine, **{f"{dataloader_name}": dataloader}
            ),
        )
        # Params
        self._dataset_name = dataloader_name
        self._ignore_missing = ignore_missing

        # State
        self._cfg = cfg
        self._dataloader = dataloader
        self._calculated = {}
        self._to_calculate: dict[str, Any] = self._update_stats(self._cfg)
        # Created lazily in `_on_started`, once the stage is actually running
        self._metric_evaluator = None

    @property
    def dataloader(self) -> reax.DataLoader | None:
        """The dataloader registered with the data manager under the configured name."""
        return self._datamanager.get_dataloader(self._dataset_name)

    @property
    def dataloaders(self) -> reax.DataLoader | None:
        """Alias of :attr:`dataloader`."""
        return self.dataloader

    @property
    def calculated(self) -> dict[str, Any]:
        """The dictionary holding the calculated statistics"""
        return self._calculated

    @override
    def log(
        self,
        name: str,
        value,
        batch_size: int | None = None,
        prog_bar: bool = False,
        logger: bool = False,
        on_step=False,
        on_epoch=True,
    ) -> None:
        """Forward a logging request, unchanged, to the currently running child stage."""
        self._child.log(
            name,
            value,
            batch_size=batch_size,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
        )

    @override
    def _on_started(self):
        """Select the metric evaluator once the stage has started."""
        super()._on_started()
        self._metric_evaluator = self._get_metric_evaluator()

    @override
    def _step(self) -> None:
        """Evaluate the currently computable statistics and write them into the configuration.

        Sets the stopper once no statistics remain to be calculated.
        """
        eval_stats = reax.stages.EvaluateStats(
            self._to_calculate,
            self._datamanager,
            self._engine,
            rngs=self._engine.rngs,
            dataset_name=self._dataset_name,
            # Honour the user's choice. This used to be hard-coded to `True`, which meant that
            # with `ignore_missing=False` the missing data was still silently skipped and the
            # corresponding labels were never removed from the set to calculate.
            ignore_missing=self._ignore_missing,
            evaluator=self._metric_evaluator,
        )
        calculated: dict = self._run_child(eval_stats).logged_metrics
        # Convert to types that can be used by omegaconf and update the configuration with the
        # values
        calculated = {label: reax.utils.arrays.to_base(stat) for label, stat in calculated.items()}

        if self._ignore_missing:
            # Set any that we couldn't calculate to `None`
            for missing in self._to_calculate.keys() - calculated.keys():
                calculated[missing] = None

        # Update for the next step
        self._cfg.update(calculated)
        self._calculated.update(calculated)

        # Find the next set to get calculated
        self._to_calculate = self._update_stats(self._cfg)
        if not self._to_calculate:
            # we're done
            self._cfg.update(self._calculated)
            self._stopper.set()

    def _update_stats(self, from_data: Mapping) -> dict:
        """Build the metrics that can be calculated in the next step.

        Args:
            from_data: mapping from label to either a metric configuration dictionary (which is
                instantiated with hydra) or a value passed to `reax.metrics.get`

        Returns:
            a dictionary mapping each label that is neither already calculated nor waiting on
            an unresolved interpolation to its metric
        """
        # Find those that we will come back to for a second pass
        # NOTE(review): `entry[0]` is the path reported by `find_iterpol`; it is empty for an
        # interpolation directly at the top level, in which case `entry[0][0]` raises
        # `IndexError` - confirm whether top-level interpolations are meant to be supported.
        with_dependencies = {entry[0][0] for entry in find_iterpol(from_data)}

        to_calculate = {}
        for label, value in from_data.items():
            if label in self._calculated or label in with_dependencies:
                continue

            if omegaconf.OmegaConf.is_dict(value):
                stat = hydra.utils.instantiate(value, _convert_="object")
            else:
                stat = reax.metrics.get(value)
            to_calculate[label] = stat

        return to_calculate

    def _get_metric_evaluator(self) -> reax.metrics.MetricEvaluator:
        """Choose a metric evaluator by inspecting the first batch from the dataloader."""
        batch = next(iter(self.dataloader))
        # For tuple batches only the first element is inspected
        example_data = batch[0] if isinstance(batch, tuple) else batch

        if isinstance(example_data, jraph.GraphsTuple) and len(example_data.n_node.shape) > 1:
            # We have batched graphs
            _LOGGER.debug(
                "FromData staging using vmap metric evaluator as len(example_data.n_node.shape)=%d",
                len(example_data.n_node.shape),
            )
            return reax.metrics.VmapEvaluator()

        return reax.metrics.DefaultEvaluator()
@functools.singledispatch
def _to_omega(value):
    """Convert `value` to a type that can be stored in an OmegaConf container.

    This fallback returns the value unchanged; overloads below handle dictionaries and arrays.
    """
    return value


@_to_omega.register
def _(value: dict) -> dict:
    # Convert each entry recursively, keeping the keys as they are
    return {key: _to_omega(item) for key, item in value.items()}


@_to_omega.register
def _(value: jt.Array) -> int | float | list:
    # Arrays are converted to base Python types
    return reax.utils.arrays.to_base(value)


def find_iterpol(root, path=()):
    """Find the interpolations in `root` that do not yet resolve to a plain value.

    Args:
        root: the (possibly nested) configuration to search
        path: the keys leading from the top-level configuration to `root`

    Yields:
        `(path, key)` pairs, where `path` is the tuple of keys leading to the container holding
        `key`, and `root[key]` is an interpolation whose value is not an `int`, `float` or `list`
    """
    for key, value in root.items():
        if isinstance(value, str):
            if omegaconf.OmegaConf.is_interpolation(root, key) and not isinstance(
                root[key], (int, float, list)
            ):
                yield path, key
        elif omegaconf.OmegaConf.is_dict(value):
            # Extend the path rather than restarting it, so that the first element always names
            # the top-level entry, however deeply the interpolation is nested
            yield from find_iterpol(value, (*path, key))