# Source code for tensorial.gcnn.atomic._metrics

from collections.abc import Mapping, Sequence

import beartype
import jax
import jax.numpy as jnp
import jaxtyping as jt
from jaxtyping import Bool, Float, Int
import jraph
from pytray import tree
import reax
from typing_extensions import override

from tensorial.typing import Array

from . import keys
from .. import graph_ops
from .. import keys as _keys
from .. import keys as graph_keys
from .. import metrics
from ... import nn_utils, utils

# Names exported by `from <this module> import *`.  Every entry is a metric defined in this
# module; the `get` helper defined below is not part of this list.
__all__ = (
    "AllAtomicNumbers",
    "NumSpecies",
    "ForceStd",
    "AvgNumNeighbours",
    "AvgNumNeighboursByAtomType",
    "TypeContributionLstsq",
    "EnergyContributionLstsq",
    "EnergyPerAtomLstsq",
)


def get(mapping: Mapping, key: str):
    """Return ``mapping[key]``, reporting a missing entry as a reax data error.

    :param mapping: the mapping to read from
    :param key: the key to look up
    :return: the value stored under ``key``
    :raises reax.exceptions.DataNotFound: if the lookup raises ``KeyError``
    """
    try:
        value = mapping[key]
    except KeyError:
        # ``from None`` drops the original KeyError from the reported traceback
        raise reax.exceptions.DataNotFound(f"Missing key: {key}") from None
    return value


# Metric type built with `reax.metrics.Unique.from_fun`.  The wrapped function receives a graph
# (any further positional arguments are accepted and ignored) and returns a 2-tuple: the node
# entry stored under `keys.ATOMIC_NUMBERS` (a missing entry raises `reax.exceptions.DataNotFound`
# via `get`) and whatever `graph.nodes.get(graph_keys.MASK)` yields for the node mask.
AllAtomicNumbers = reax.metrics.Unique.from_fun(
    lambda graph, *_: (get(graph.nodes, keys.ATOMIC_NUMBERS), graph.nodes.get(graph_keys.MASK)),
    name="AtomicNumbers",
)


# Metric type built with `reax.metrics.NumUnique.from_fun`.  The wrapped function returns the node
# entry stored under `keys.ATOMIC_NUMBERS` (a missing entry raises `reax.exceptions.DataNotFound`
# via `get`) together with whatever `graph.nodes.get(graph_keys.MASK)` yields for the node mask.
# Like the other graph metrics in this module, the function accepts and ignores any extra
# positional arguments so that it can be called with more than just the graph.
NumSpecies = reax.metrics.NumUnique.from_fun(
    lambda graph, *_: (get(graph.nodes, keys.ATOMIC_NUMBERS), graph.nodes.get(graph_keys.MASK)),
    name="Species",
)


# Metric type built with `reax.metrics.Std.from_fun`.  The wrapped function returns the node entry
# stored under `keys.FORCES` (a missing entry raises `reax.exceptions.DataNotFound` via `get`)
# together with whatever `graph.nodes.get(graph_keys.MASK)` yields for the node mask.  Like the
# other graph metrics in this module, the function accepts and ignores any extra positional
# arguments so that it can be called with more than just the graph.
ForceStd = reax.metrics.Std.from_fun(
    lambda graph, *_: (get(graph.nodes, keys.FORCES), graph.nodes.get(graph_keys.MASK)),
    name="Force",
)


# Metric type built with `reax.metrics.Average.from_fun`.  The wrapped function counts how often
# each node index occurs in `graph.senders` (one bin per row of the node positions array) and
# returns those per-node counts together with whatever `graph.nodes.get(graph_keys.MASK)` yields
# for the node mask.  Extra positional arguments are accepted and ignored.
AvgNumNeighbours = reax.metrics.Average.from_fun(
    lambda graph, *_: (
        # `length` fixes the number of bins to the number of nodes
        jnp.bincount(graph.senders, length=graph.nodes[_keys.POSITIONS].shape[0]),
        graph.nodes.get(graph_keys.MASK),
    )
)


class EnergyPerAtomLstsq(reax.metrics.FromFun):
    """Calculate the least squares estimate of the energy per atom"""

    # Underlying metric that receives the values produced by `func`
    metric = reax.metrics.LeastSquaresEstimate

    @staticmethod
    def func(graph, *_):
        """Extract the inputs for the fit from a graph.

        :param graph: graph(s) providing ``n_node`` and the ``keys.TOTAL_ENERGY`` global
        :return: tuple of the node counts reshaped to ``(-1, 1)`` and the total energies
            flattened to one dimension
        """
        atom_counts = graph.n_node.reshape(-1, 1)
        total_energies = graph.globals[keys.TOTAL_ENERGY].reshape(-1)
        return atom_counts, total_energies

    def compute(self) -> jax.Array:
        """Return the parent's result reshaped to a 0-dimensional array."""
        estimate = super().compute()
        return estimate.reshape(())
class TypeContributionLstsq(reax.metrics.Metric[Array]):
    """Online Least Squares Metric.

    Uses 'Sufficient Statistics' (XtX, Xty) to perform linear regression
    without storing the entire dataset history.
    """

    # XtX: The Gram Matrix (A.T @ A) -> Shape: (n_types, n_types)
    xtx: Float[jax.Array, "n_types n_types"] | None = None
    # Xty: The Moment Vector (A.T @ b) -> Shape: (n_types, ...)
    xty: Float[jax.Array, "n_types ..."] | None = None

    @property
    def is_empty(self):
        """``True`` while no batch has been accumulated (``xtx`` is ``None``)."""
        return self.xtx is None

    @classmethod
    @override
    def empty(cls) -> "TypeContributionLstsq":  # pylint: disable=arguments-differ
        """Return a metric holding no statistics."""
        return cls()

    @classmethod
    @jt.jaxtyped(typechecker=beartype.beartype)
    @override
    def create(  # pylint: disable=arguments-differ
        cls,
        type_counts: Int[Array, "batch_size n_types"] | Float[Array, "batch_size n_types"],
        values: Float[Array, "batch_size ..."],
        mask: Bool[Array, "batch_size"] | None = None,
        /,
    ) -> "TypeContributionLstsq":
        """Build the sufficient statistics for a single batch.

        :param type_counts: design matrix ``A`` with one row per batch entry
        :param values: targets ``b`` with a leading batch dimension
        :param mask: optional per-row flags; rows whose flag is ``False`` are replaced by zeros
            so that they add nothing to the statistics
        :return: a metric holding ``A.T @ A`` and ``A.T @ b`` for this batch
        """
        # NOTE(review): the matrix product below contracts the batch axis correctly for 1-D and
        # 2-D ``values`` only; higher-rank targets look unsupported here - confirm with callers.
        np_ = utils.infer_backend(type_counts)

        # 1. Cast inputs to float for matrix operations
        a_mtx = type_counts.astype(np_.float32)
        b_vec = values

        # 2. Apply shape-stable masking
        # Instead of A[mask] (which changes shape), we use jnp.where.
        if mask is not None:
            # CRITICAL: Use jnp.where instead of (A * mask).
            # If the masked-out rows in A contain NaNs, (NaN * 0) is still NaN.
            # jnp.where ensures safe zeros are used for ignored rows.
            # Broadcast mask: (batch,) -> (batch, 1)
            a_mtx = jnp.where(mask[:, None], a_mtx, 0.0)

            # Align the mask with b: (batch,) -> (batch, 1, ..., 1).  For 1-D targets this
            # reshape leaves the mask as it is.
            mask_b = mask.reshape((mask.shape[0],) + (1,) * (b_vec.ndim - 1))
            b_vec = jnp.where(mask_b, b_vec, 0.0)

        # 3. Compute Sufficient Statistics for this batch
        # Since ignored rows are now exactly 0.0, they add nothing to the result
        # of the matrix multiplication, effectively filtering them out.

        # A.T @ A -> (n_types, n_types)
        batch_xtx = a_mtx.T @ a_mtx
        # A.T @ b -> (n_types, ...)
        batch_xty = a_mtx.T @ b_vec

        return cls(xtx=batch_xtx, xty=batch_xty)

    @jt.jaxtyped(typechecker=beartype.beartype)
    @override
    def update(  # pylint: disable=arguments-differ
        self,
        type_counts: Int[Array, "batch_size n_types"] | Float[Array, "batch_size n_types"],
        values: Float[Array, "batch_size ..."],
        mask: Bool[Array, "batch_size"] | None = None,
        /,
    ) -> "TypeContributionLstsq":
        """Return a new metric with the statistics of one more batch added.

        The arguments have the same meaning as in :meth:`create`.
        """
        # Calculate stats for the incoming batch
        batch_metric = self.create(type_counts, values, mask)

        if self.is_empty:
            return batch_metric

        # Accumulate: Simple element-wise addition of the matrices.  ``type(self)`` keeps the
        # result type in line with ``create``, which builds instances through ``cls``.
        return type(self)(xtx=self.xtx + batch_metric.xtx, xty=self.xty + batch_metric.xty)

    @override
    def merge(self, other: "TypeContributionLstsq") -> "TypeContributionLstsq":
        """Combine two metrics; an empty operand yields the other one unchanged."""
        if self.is_empty:
            return other
        if other.is_empty:
            return self

        # Merging is just adding the sufficient statistics
        return type(self)(xtx=self.xtx + other.xtx, xty=self.xty + other.xty)

    @override
    def compute(self, regularization: float = 1e-6):
        """Computes the fit using the Moore-Penrose pseudo-inverse approach.

        This naturally handles rank-deficient cases (like fixed 50:50 ratios)
        by producing the minimum-norm solution, which assigns equal
        contributions to indistinguishable types.

        :param regularization: relative cutoff; eigenvalues of ``xtx`` that are not larger than
            ``regularization`` times the largest eigenvalue are treated as zero
        :return: the fitted weights, with the same shape as ``xty``
        :raises RuntimeError: if no data has been accumulated
        """
        if self.is_empty:
            raise RuntimeError("This metric is empty, cannot compute!")

        # Allow more mathsy names
        # pylint: disable=invalid-name
        np_ = utils.infer_backend(self.xtx)

        # Solve Normal Equation: (A.T A) x = A.T b
        # We solve for x in: xtx @ x = xty

        # 1. Since XtX is symmetric, eigh is more efficient and stable than SVD
        # s: eigenvalues, V: eigenvectors
        s, V = np_.linalg.eigh(self.xtx)

        # 2. Determine the threshold for 'zero' eigenvalues
        # Standard practice is a fraction of the largest eigenvalue
        threshold = regularization * np_.max(s)

        # 3. Compute the pseudo-inverse of the eigenvalues
        # We only invert values above the threshold, others become 0.0.  The discarded
        # eigenvalues are swapped for 1.0 before dividing so that no division by zero (and no
        # inf/NaN intermediate) is ever evaluated.
        keep = s > threshold
        s_inv = np_.where(keep, 1.0 / np_.where(keep, s, 1.0), 0.0)

        # 4. Reconstruct the solution: x = V @ diag(s_inv) @ V.T @ xty
        # This is the Moore-Penrose solution (minimum L2 norm)
        # Equivalent to: x = pinv(xtx) @ xty
        projected = V.T @ self.xty
        # Give s_inv one trailing axis per trailing axis of xty.  A fixed (n, 1) shape would
        # broadcast against a 1-D xty into an (n, n) outer product instead of scaling it
        # element-wise.
        s_inv = s_inv.reshape((-1,) + (1,) * (projected.ndim - 1))
        weights = V @ (s_inv * projected)

        return weights
class EnergyContributionLstsq(reax.Metric):
    """Least-squares fit of per-type contributions to the energy per atom.

    For every graph the type counts and the total energy are both divided by the number of
    nodes, and the resulting rows are accumulated in a :class:`TypeContributionLstsq`.
    """

    _type_map: Array
    _metric: TypeContributionLstsq | None = None

    def __init__(self, type_map: Sequence | Array, metric: TypeContributionLstsq | None = None):
        """
        :param type_map: the atomic numbers to map onto the type indices 0, 1, 2, ...
        :param metric: already accumulated statistics, or ``None`` for an empty metric
        :raises ValueError: if ``type_map`` is ``None``
        """
        if type_map is None:
            raise ValueError("Must supply a value type_map")
        self._type_map = jnp.asarray(type_map)
        self._metric = metric

    @override
    def empty(self) -> "EnergyContributionLstsq":
        """Return a metric with the same type map and no accumulated data."""
        if self._metric is None:
            return self

        return type(self)(type_map=self._type_map)

    @override
    def merge(self, other: "EnergyContributionLstsq") -> "EnergyContributionLstsq":
        """Combine the accumulated statistics of ``self`` and ``other``."""
        if other._metric is None:  # pylint: disable=protected-access
            return self
        if self._metric is None:
            return other

        return type(self)(
            type_map=self._type_map,
            metric=self._metric.merge(other._metric),  # pylint: disable=protected-access
        )

    @override
    def create(  # pylint: disable=arguments-differ
        self, graphs: jraph.GraphsTuple, *_
    ) -> "EnergyContributionLstsq":
        """Return a new metric holding only the statistics of ``graphs``."""
        val = self._fun(graphs)  # pylint: disable=not-callable
        return type(self)(type_map=self._type_map, metric=TypeContributionLstsq.create(*val))

    @override
    def update(  # pylint: disable=arguments-differ
        self, graphs: jraph.GraphsTuple, *_
    ) -> "EnergyContributionLstsq":
        """Return a new metric with the statistics of ``graphs`` added."""
        if self._metric is None:
            return self.create(graphs)

        val = self._fun(graphs)  # pylint: disable=not-callable
        return type(self)(type_map=self._type_map, metric=self._metric.update(*val))

    @override
    def compute(self, regularization: float = 1e-6):
        """Compute the fitted contributions.

        :param regularization: forwarded to :meth:`TypeContributionLstsq.compute`
        :raises RuntimeError: if no data has been accumulated
        """
        if self._metric is None:
            raise RuntimeError("Nothing to compute, metric is empty!")

        return self._metric.compute(regularization=regularization)

    @jt.jaxtyped(typechecker=beartype.beartype)
    def _fun(self, graphs: jraph.GraphsTuple, *_) -> tuple[
        Float[Array, "batch_size k"],
        Float[Array, "batch_size 1"],
        Bool[Array, "batch_size"] | None,
    ]:
        """Extract the inputs for the least-squares fit from ``graphs``.

        :return: tuple of the type counts per graph, the total energies per graph (both divided
            by the number of nodes of the graph) and the optional mask stored in the globals
        :raises reax.exceptions.DataNotFound: if the atomic numbers or the total energy are
            missing from the graph
        """
        graph_dict = graphs._asdict()
        num_nodes = graphs.n_node

        try:
            types = tree.get_by_path(graph_dict, ("nodes", keys.ATOMIC_NUMBERS))
        except KeyError:
            # Report the key that was actually looked up
            raise reax.exceptions.DataNotFound(
                f"Missing key: {('nodes', keys.ATOMIC_NUMBERS)}"
            ) from None

        if self._type_map is None:
            # Fallback only: ``__init__`` rejects ``None`` and always stores an array
            num_classes = types.max().item() + 1  # Assume the types go 0,1,2...N
        else:
            # Transform the atomic numbers from whatever they are to 0, 1, 2....
            types = nn_utils.vwhere(types, self._type_map)
            num_classes = len(self._type_map)

        one_hots = jax.nn.one_hot(types, num_classes)

        # Store the one-hot encoding in a shallow copy of the nodes so that the graph passed in
        # by the caller is not modified
        one_hot_field = ("type_one_hot",)
        nodes = dict(graphs.nodes)
        tree.set_by_path(nodes, one_hot_field, one_hots)
        type_counts = graph_ops.graph_segment_reduce(
            graphs._replace(nodes=nodes), ("nodes",) + one_hot_field, reduction="sum"
        )

        # Predicting values
        try:
            values = tree.get_by_path(graph_dict, ("globals", keys.TOTAL_ENERGY))
        except KeyError:
            raise reax.exceptions.DataNotFound(
                f"Missing key: {('globals', keys.TOTAL_ENERGY)}"
            ) from None

        if graph_keys.MASK in graph_dict["globals"]:
            mask = graph_dict["globals"][graph_keys.MASK]
        else:
            mask = None

        # Normalise by number of nodes
        per_node = jax.vmap(lambda numer, denom: numer / denom, (0, 0))
        type_counts = per_node(type_counts, num_nodes)
        values = per_node(values, num_nodes)

        return type_counts, values, mask
class AvgNumNeighboursByAtomType(metrics.AvgNumNeighboursByType):
    """``metrics.AvgNumNeighboursByType`` whose type field defaults to ``keys.ATOMIC_NUMBERS``."""

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        atom_types: Sequence[int] | Int[Array, "n_types"],
        type_field: str = keys.ATOMIC_NUMBERS,
        state: metrics.AvgNumNeighboursByType.Averages | None = None,
    ):
        """
        :param atom_types: the atom types, forwarded to the parent class
        :param type_field: name of the field holding the types, forwarded to the parent class
        :param state: optional existing state, forwarded to the parent class
        """
        # All three arguments are handed to the parent positionally, in this order
        super().__init__(atom_types, type_field, state)