from collections.abc import Mapping
import functools
import logging
from typing import Any
import beartype
from flax import nnx
import hydra
import jaxtyping as jt
import jraph
import omegaconf
import reax
import reax.utils
from typing_extensions import override
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# NOTE(review): not referenced anywhere in the visible part of this file -- confirm whether
# external code relies on this key before removing it.
DONE_KEY = "done"
# Names exported by ``from <module> import *``.
__all__ = ("FromData",)
class FromData(reax.stages.Stage):
    """A trainer stage that will populate an OmegaConf dictionary with data statistics calculated
    from metrics.

    Every top-level entry of the configuration describes one statistic.  Dictionary entries are
    instantiated with ``hydra.utils.instantiate``; any other entry is looked up with
    ``reax.metrics.get``.  Entries for which ``find_iterpol`` reports an interpolation are
    postponed: after each step the calculated values are written back into the configuration
    and the set of statistics to calculate is determined again.  The stage stops once there is
    nothing left to calculate.
    """

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        cfg: omegaconf.DictConfig,
        engine: reax.Engine,
        *,
        rngs: nnx.Rngs | None = None,
        dataloader: reax.DataLoader | None = None,
        datamodule: reax.DataModule | None = None,
        dataloader_name: str | None = "train",
        ignore_missing: bool = True,
    ):
        """Populate a hydra configurations dictionary using calculated stats

        Args:
            cfg: the configuration dictionary; it is updated in place with the calculated
                values
            engine: the engine handed to the base stage, the data manager and the child
                evaluation stages
            rngs: the random number generator
            dataloader: the dataloader to use
            datamodule: if no dataloader is specified, a data module can
                be used instead
            dataloader_name: the datamodule dataloader name
            ignore_missing: if `True`, any data that is needed to
                calculate a metric but is missing will be ignored, and
                that metric will not be calculated (its entry is set to `None`)
        """
        super().__init__(
            "from_data",
            module=None,
            engine=engine,
            rngs=rngs,
            datamanager=reax.data.create_manager(
                datamodule=datamodule, engine=engine, **{f"{dataloader_name}": dataloader}
            ),
        )
        # Params
        self._dataset_name = dataloader_name
        self._ignore_missing = ignore_missing

        # State
        self._cfg = cfg
        self._dataloader = dataloader
        # Must exist before `_update_stats` is called, which reads it
        self._calculated: dict[str, Any] = {}
        self._to_calculate: dict[str, Any] = self._update_stats(self._cfg)
        self._metric_evaluator = None

    @property
    def dataloader(self) -> reax.DataLoader | None:
        """The dataloader registered with the data manager under the configured name."""
        return self._datamanager.get_dataloader(self._dataset_name)

    @property
    def dataloaders(self) -> reax.DataLoader | None:
        """Alias of :attr:`dataloader`."""
        return self.dataloader

    @property
    def calculated(self) -> dict[str, Any]:
        """The dictionary holding the calculated statistics"""
        return self._calculated

    @override
    def log(
        self,
        name: str,
        value,
        batch_size: int | None = None,
        prog_bar: bool = False,
        logger: bool = False,
        on_step=False,
        on_epoch=True,
    ) -> None:
        """Forward a logged value, with all options unchanged, to ``self._child.log``."""
        self._child.log(
            name,
            value,
            batch_size=batch_size,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
        )

    @override
    def _on_started(self):
        """Run the base class hook, then choose the metric evaluator used by every step."""
        super()._on_started()
        self._metric_evaluator = self._get_metric_evaluator()

    @override
    def _step(self) -> None:
        """Calculate the currently pending statistics and write them into the configuration.

        A ``reax.stages.EvaluateStats`` child stage is run over the pending statistics, its
        logged metrics are converted with ``reax.utils.arrays.to_base`` and stored in both the
        configuration and :attr:`calculated`.  Afterwards the pending set is recomputed and the
        stage is stopped when it is empty.
        """
        eval_stats = reax.stages.EvaluateStats(
            self._to_calculate,
            self._datamanager,
            self._engine,
            rngs=self._engine.rngs,
            dataset_name=self._dataset_name,
            # Honour the user's choice instead of always ignoring missing data: with a
            # hard-coded `True` and `ignore_missing=False`, a statistic that cannot be
            # calculated would be silently skipped here, never be marked as done below and
            # therefore be queued again on every subsequent step.
            ignore_missing=self._ignore_missing,
            evaluator=self._metric_evaluator,
        )
        logged: dict = self._run_child(eval_stats).logged_metrics

        # Convert to types that can be used by omegaconf and update the configuration with the
        # values
        calculated = {label: reax.utils.arrays.to_base(stat) for label, stat in logged.items()}

        if self._ignore_missing:
            # Set any that we couldn't calculate to `None`
            for missing in self._to_calculate.keys() - calculated.keys():
                calculated[missing] = None

        # Update for the next step
        self._cfg.update(calculated)
        self._calculated.update(calculated)

        # Find the next set to get calculated
        self._to_calculate = self._update_stats(self._cfg)
        if not self._to_calculate:
            # we're done
            self._cfg.update(self._calculated)
            self._stopper.set()

    def _update_stats(self, from_data: Mapping) -> dict:
        """Build the statistics that can be calculated in the next step.

        Args:
            from_data: mapping from statistic label to its description

        Returns:
            a dictionary mapping each label that is neither calculated yet nor reported by
            ``find_iterpol`` to its statistic object
        """
        # Labels we will come back to in a later pass: the first element of the path of every
        # interpolation that `find_iterpol` reports.
        # NOTE(review): an interpolation found directly at the top level has an empty path, so
        # `path[0]` raises `IndexError` -- confirm whether top-level interpolations can occur.
        with_dependencies = {path[0] for path, _key in find_iterpol(from_data)}

        to_calculate = {}
        for label, value in from_data.items():
            if label in self._calculated or label in with_dependencies:
                continue

            if omegaconf.OmegaConf.is_dict(value):
                stat = hydra.utils.instantiate(value, _convert_="object")
            else:
                stat = reax.metrics.get(value)
            to_calculate[label] = stat

        return to_calculate

    def _get_metric_evaluator(self) -> reax.metrics.MetricEvaluator:
        """Choose a metric evaluator by inspecting the first batch of the dataloader.

        If the batch is a tuple, its first element is inspected.  A ``jraph.GraphsTuple`` whose
        ``n_node`` array has more than one dimension gets a ``VmapEvaluator``; everything else
        gets a ``DefaultEvaluator``.
        """
        batch = next(iter(self.dataloader))
        example_data = batch[0] if isinstance(batch, tuple) else batch

        if isinstance(example_data, jraph.GraphsTuple) and len(example_data.n_node.shape) > 1:
            # We have batched graphs
            _LOGGER.debug(
                "FromData staging using vmap metric evaluator as len(example_data.n_node.shape)=%d",
                len(example_data.n_node.shape),
            )
            return reax.metrics.VmapEvaluator()

        return reax.metrics.DefaultEvaluator()
@functools.singledispatch
def _to_omega(value):
return value
@_to_omega.register
def _(value: dict) -> dict:
    """``dict`` overload of ``_to_omega``: convert every value, keep the keys untouched."""
    converted = {}
    for key, item in value.items():
        converted[key] = _to_omega(item)
    return converted
@_to_omega.register
def _(value: jt.Array) -> int | float | list:
    """Array overload of ``_to_omega``: delegate to ``reax.utils.arrays.to_base``."""
    base_value = reax.utils.arrays.to_base(value)
    return base_value
def find_iterpol(root, path=()):
    """Recursively search a configuration for interpolations.

    Args:
        root: the mapping to search; it is iterated with ``.items()`` and queried with
            ``omegaconf.OmegaConf.is_interpolation(root, key)``
        path: the keys leading from the outermost mapping down to ``root``

    Yields:
        ``(path, key)`` tuples, one for each string-valued entry ``key`` of a (possibly nested)
        mapping that OmegaConf reports as an interpolation and whose value ``root[key]`` is not
        an ``int``, ``float`` or ``list``.  ``path`` holds all ancestor keys, outermost first.
    """
    for key, value in root.items():
        if isinstance(value, str):
            if omegaconf.OmegaConf.is_interpolation(root, key) and not isinstance(
                root[key], (int, float, list)
            ):
                yield path, key
        elif omegaconf.OmegaConf.is_dict(value):
            # Extend the path rather than restarting it at `key`, otherwise entries nested more
            # than one level deep lose their ancestors and `path[0]` is no longer the
            # top-level key.
            yield from find_iterpol(value, (*path, key))