from collections.abc import Mapping, Sequence
import beartype
import jax
import jax.numpy as jnp
import jaxtyping as jt
import jraph
from pytray import tree
import reax
from typing_extensions import override
from tensorial.typing import Array
from . import keys
from .. import graph_ops
from .. import keys as graph_keys
from .. import metrics
from ... import nn_utils, utils
__all__ = (
"AllAtomicNumbers",
"NumSpecies",
"ForceStd",
"AvgNumNeighbours",
"AvgNumNeighboursByAtomType",
"TypeContributionLstsq",
"EnergyContributionLstsq",
"EnergyPerAtomLstsq",
)
def get(mapping: Mapping, key: str):
try:
return mapping[key]
except KeyError:
raise reax.exceptions.DataNotFound(f"Missing key: {key}") from None
AllAtomicNumbers = reax.metrics.Unique.from_fun(
lambda graph, *_: (get(graph.nodes, keys.ATOMIC_NUMBERS), graph.nodes.get(graph_keys.MASK)),
name="AtomicNumbers",
)
NumSpecies = reax.metrics.NumUnique.from_fun(
lambda graph: (get(graph.nodes, keys.ATOMIC_NUMBERS), graph.nodes.get(graph_keys.MASK)),
name="Species",
)
ForceStd = reax.metrics.Std.from_fun(
lambda graph: (get(graph.nodes, keys.FORCES), graph.nodes.get(graph_keys.MASK)), name="Force"
)
AvgNumNeighbours = reax.metrics.Average.from_fun(
lambda graph, *_: (
jnp.bincount(graph.senders, length=jnp.sum(graph.n_node)),
graph.nodes.get(graph_keys.MASK),
)
)
[docs]
class EnergyPerAtomLstsq(reax.metrics.FromFun):
"""Calculate the least squares estimate of the energy per atom"""
metric = reax.metrics.LeastSquaresEstimate
[docs]
@staticmethod
def func(graph, *_):
return graph.n_node.reshape(-1, 1), graph.globals[keys.TOTAL_ENERGY].reshape(-1)
[docs]
def compute(self) -> jax.Array:
return super().compute().reshape(())
[docs]
class TypeContributionLstsq(reax.metrics.Metric[Array]):
"""Online Least Squares Metric.
Uses 'Sufficient Statistics' (XtX, Xty) to perform linear regression
without storing the entire dataset history.
"""
# XtX: The Gram Matrix (A.T @ A) -> Shape: (n_types, n_types)
xtx: jt.Float[jax.Array, "n_types n_types"] | None = None
# Xty: The Moment Vector (A.T @ b) -> Shape: (n_types, ...)
xty: jt.Float[jax.Array, "n_types ..."] | None = None
@property
def is_empty(self):
return self.xtx is None
[docs]
@classmethod
@override
def empty(cls) -> "TypeContributionLstsq": # pylint: disable=arguments-differ
return cls()
[docs]
@classmethod
@jt.jaxtyped(typechecker=beartype.beartype)
@override
def create( # pylint: disable=arguments-differ
cls,
type_counts: jt.Int[Array, "batch_size n_types"] | jt.Float[Array, "batch_size n_types"],
values: jt.Float[Array, "batch_size ..."],
mask: jt.Bool[Array, "batch_size"] | None = None,
/,
) -> "TypeContributionLstsq":
np_ = utils.infer_backend(type_counts)
# 1. Cast inputs to float for matrix operations
a_mtx = type_counts.astype(np_.float32)
b_vec = values
# 2. Apply shape-stable masking
# Instead of A[mask] (which changes shape), we use jnp.where.
if mask is not None:
# Broadcast mask: (batch,) -> (batch, 1)
mask_expanded = mask[:, None]
# CRITICAL: Use jnp.where instead of (A * mask).
# If the masked-out rows in A contain NaNs, (NaN * 0) is still NaN.
# jnp.where ensures safe zeros are used for ignored rows.
a_mtx = jnp.where(mask_expanded, a_mtx, 0.0)
# Handle masking for 'b' (values)
# We align the mask dimensions to match b
if b_vec.ndim > 1:
# If b is (batch, targets), reshape mask to (batch, 1)
mask_b = mask.reshape((mask.shape[0],) + (1,) * (b_vec.ndim - 1))
else:
# If b is (batch,), standard mask works
mask_b = mask
b_vec = jnp.where(mask_b, b_vec, 0.0)
# 3. Compute Sufficient Statistics for this batch
# Since ignored rows are now exactly 0.0, they add nothing to the result
# of the matrix multiplication, effectively filtering them out.
# A.T @ A -> (n_types, n_types)
batch_xtx = a_mtx.T @ a_mtx
# A.T @ b -> (n_types, ...)
batch_xty = a_mtx.T @ b_vec
return cls(xtx=batch_xtx, xty=batch_xty)
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
@override
def update( # pylint: disable=arguments-differ
self,
type_counts: jt.Int[Array, "batch_size n_types"] | jt.Float[Array, "batch_size n_types"],
values: jt.Float[Array, "batch_size ..."],
mask: jt.Bool[Array, "batch_size"] | None = None,
/,
) -> "TypeContributionLstsq":
# Calculate stats for the incoming batch
batch_metric = self.create(type_counts, values, mask)
if self.is_empty:
return batch_metric
# Accumulate: Simple element-wise addition of the matrices
return TypeContributionLstsq(
xtx=self.xtx + batch_metric.xtx, xty=self.xty + batch_metric.xty
)
[docs]
@override
def merge(self, other: "TypeContributionLstsq") -> "TypeContributionLstsq":
if self.is_empty:
return other
if other.is_empty:
return self
# Merging is just adding the sufficient statistics
return TypeContributionLstsq(xtx=self.xtx + other.xtx, xty=self.xty + other.xty)
[docs]
@override
def compute(self, regularization: float = 1e-6):
if self.is_empty:
raise RuntimeError("This metric is empty, cannot compute!")
np_ = utils.infer_backend(self.xtx)
# Solve Normal Equation: (A.T A) x = A.T b
# We solve for x in: xtx @ x = xty
# Add small ridge regularization for numerical stability
# (prevents crash if matrix is singular or data was empty)
eye = np_.eye(self.xtx.shape[0])
safe_xtx = self.xtx + (eye * regularization)
return np_.linalg.solve(safe_xtx, self.xty)
[docs]
class EnergyContributionLstsq(reax.Metric):
_type_map: jt.Array
_metric: TypeContributionLstsq | None = None
def __init__(self, type_map: Sequence | Array, metric: TypeContributionLstsq = None):
if type_map is None:
raise ValueError("Must supply a value type_map")
self._type_map = jnp.asarray(type_map)
self._metric = metric
[docs]
@override
def empty(self) -> "EnergyContributionLstsq":
if self._metric is None:
return self
return EnergyContributionLstsq(self._type_map)
[docs]
@override
def merge(self, other: "EnergyContributionLstsq") -> "EnergyContributionLstsq":
if other._metric is None: # pylint: disable=protected-access
return self
if self._metric is None:
return other
return type(self)(
type_map=self._type_map,
metric=self._metric.merge(other._metric), # pylint: disable=protected-access
)
[docs]
@override
def create( # pylint: disable=arguments-differ
self, graphs: jraph.GraphsTuple, *_
) -> "EnergyContributionLstsq":
val = self._fun(graphs) # pylint: disable=not-callable
return type(self)(type_map=self._type_map, metric=TypeContributionLstsq.create(*val))
[docs]
@override
def update( # pylint: disable=arguments-differ
self, graphs: jraph.GraphsTuple, *_
) -> "EnergyContributionLstsq":
if self._metric is None:
return self.create(graphs)
val = self._fun(graphs) # pylint: disable=not-callable
return EnergyContributionLstsq(type_map=self._type_map, metric=self._metric.update(*val))
[docs]
@override
def compute(self):
if self._metric is None:
raise RuntimeError("Nothing to compute, metric is empty!")
return self._metric.compute()
@jt.jaxtyped(typechecker=beartype.beartype)
def _fun(self, graphs: jraph.GraphsTuple, *_) -> tuple[
jt.Float[Array, "batch_size k"],
jt.Float[Array, "batch_size 1"],
jt.Bool[Array, "batch_size"] | None,
]:
graph_dict = graphs._asdict()
num_nodes = graphs.n_node
try:
types = tree.get_by_path(graph_dict, ("nodes", keys.ATOMIC_NUMBERS))
except KeyError:
raise reax.exceptions.DataNotFound(
f"Missing key: {('nodes', keys.TOTAL_ENERGY)}"
) from None
if self._type_map is None:
num_classes = types.max().item() + 1 # Assume the types go 0,1,2...N
else:
# Transform the atomic numbers from whatever they are to 0, 1, 2....
types = nn_utils.vwhere(types, self._type_map)
num_classes = len(self._type_map)
one_hots = jax.nn.one_hot(types, num_classes)
# TODO: make it so we don't need to set the value in the graph
one_hot_field = ("type_one_hot",)
tree.set_by_path(graphs.nodes, one_hot_field, one_hots)
type_counts = graph_ops.graph_segment_reduce(
graphs, ("nodes",) + one_hot_field, reduction="sum"
)
# Predicting values
try:
values = tree.get_by_path(graph_dict, ("globals", keys.TOTAL_ENERGY))
except KeyError:
raise reax.exceptions.DataNotFound(
f"Missing key: {('globals', keys.TOTAL_ENERGY)}"
) from None
if graph_keys.MASK in graph_dict["globals"]:
mask = graph_dict["globals"][graph_keys.MASK]
else:
mask = None
# Normalise by number of nodes
type_counts = jax.vmap(lambda numer, denom: numer / denom, (0, 0))(type_counts, num_nodes)
values = jax.vmap(lambda numer, denom: numer / denom, (0, 0))(values, num_nodes)
return type_counts, values, mask
[docs]
class AvgNumNeighboursByAtomType(metrics.AvgNumNeighboursByType):
@jt.jaxtyped(typechecker=beartype.beartype)
def __init__(
self,
atom_types: Sequence[int] | jt.Int[Array, "n_types"],
type_field: str = keys.ATOMIC_NUMBERS,
state: metrics.AvgNumNeighboursByType.Averages | None = None,
):
super().__init__(atom_types, type_field, state)