Source code for tensorial.gcnn._diff

import abc
from collections.abc import Sequence
import dataclasses
import re
from typing import TYPE_CHECKING, Final, Protocol, Union

import flax.core
import jax
import jaxtyping as jt
import jraph

from . import _base, _tree
from .. import base

if TYPE_CHECKING:
    from tensorial import gcnn

__all__ = ("diff",)

DERIV_DELIMITER: Final[str] = ","
ArgumentSpecifier = Union[int, "gcnn.typing.TreePath"]


class DerivableGraphFunction(Protocol):
    def __call__(
        self, graph: jraph.GraphsTuple, *args: jt.PyTree
    ) -> jt.Array | tuple[jt.Array, jraph.GraphsTuple]: ...


[docs] @dataclasses.dataclass(frozen=True, slots=True) class GraphEntrySpec: """ A specification that identifies a particular entry in a hierarchical data structure (e.g., a PyTree) along with optional index labels used for differentiable computations. Attributes: key_path (Optional[gcnn.typing.TreePath]): A path to the target node in the PyTree, typically represented as a tuple of keys (e.g., strings or integers). indices (Optional[str]): A string representing symbolic indices, often used to annotate tensor dimensions for operations like differentiation. """ key_path: ArgumentSpecifier | None indices: str | None
[docs] @classmethod def create(cls, spec: "GraphEntrySpecLike") -> "GraphEntrySpec": if isinstance(spec, GraphEntrySpec): return spec if spec is None: return GraphEntrySpec(None, None) match = re.match(r"^(?:(.*?))?(?::(.*))?$", spec) if not match: raise ValueError(f"Could not parse the expression: {spec}") groups = match.groups() try: kee_path = int(groups[0]) except ValueError: kee_path = _tree.path_from_str(groups[0]) return GraphEntrySpec(kee_path, groups[1])
@property def safe_indices(self) -> str: return self.indices if self.indices is not None else "" def __str__(self) -> str: rep = [] if self.key_path is not None: key_path = ( str(self.key_path) if isinstance(self.key_path, int) else _tree.path_to_str(self.key_path) ) rep.append(key_path) if self.indices is not None: rep.append(f":{self.indices}") return "".join(rep) def __truediv__(self, other: "GraphEntrySpecLike") -> "SingleDerivative": return SingleDerivative.create(self, other)
[docs] def index_union(self, other: "GraphEntrySpec") -> str | None: # use dictionary keys as a set out = dict() if not self.indices else dict.fromkeys(self.indices) if other.indices: out |= dict.fromkeys(other.indices) if not out: return None return "".join(out)
[docs] def indices_intersection(self, other: "GraphEntrySpec") -> str: if not self.indices or not other.indices: return "" out = [index for index in self.indices if index in other.indices] return "".join(out)
GraphEntrySpecLike = str | GraphEntrySpec class Derivative(abc.ABC): @property @abc.abstractmethod def of(self) -> GraphEntrySpec: """Derivative of""" @property @abc.abstractmethod def graph_tuple_paths(self) -> "dict[gcnn.typing.TreePath, int]": ... @property @abc.abstractmethod def argnum_paths(self) -> dict[int, int]: ... @property @abc.abstractmethod def out(self) -> GraphEntrySpec: """Derivative output""" def evaluator( self, func: DerivableGraphFunction, return_graph: bool, argnum: int = 0, scale: float = 1.0, at: dict | None = None, ) -> "Evaluator": if at is not None: at = flax.core.FrozenDict(at) return Evaluator(func, self, return_graph, argnum, scale=scale, at=at) @abc.abstractmethod def build_derivative_fn( self, func: DerivableGraphFunction, return_graph: bool, argnum: int ) -> DerivableGraphFunction: """Get evaluate function from derivative""" def adapt(self, func: "gcnn.typing.ExGraphFunction") -> DerivableGraphFunction: return _base.adapt( func, *self.graph_tuple_paths, outs=tuple() if not self.of or not self.of.key_path else [self.of.key_path], return_graphs=True, ) def infer_out_indices(of: GraphEntrySpec, wrt: GraphEntrySpec) -> str: # All indices resulting from differentiating 'of' with respect to 'wrt' all_deriv_indices = of.safe_indices + wrt.safe_indices # Find indices common to both 'of' and 'wrt' — these will be reduced (summed over) shared_indices = of.indices_intersection(wrt) of_reduce = tuple(of.indices.index(i) for i in shared_indices) return "".join([idx for i, idx in enumerate(all_deriv_indices) if i not in of_reduce])
[docs] @dataclasses.dataclass(frozen=True, slots=True) class SingleDerivative(Derivative): _of: GraphEntrySpec _wrt: GraphEntrySpec _out: GraphEntrySpec # will be set in __post_init__ _pre_reduce: tuple[int, ...] = dataclasses.field(init=False) _post_reduce: tuple[int, ...] = dataclasses.field(init=False) _post_permute: tuple[int, ...] = dataclasses.field(init=False) _actual_out: GraphEntrySpec = dataclasses.field(init=False) def __post_init__(self): # All indices resulting from differentiating 'of' with respect to 'wrt' all_deriv_indices = self.of.safe_indices + self.wrt.safe_indices # now check that the output indices are some subset of these if self._out.indices is not None and not set(self._out.indices).issubset(all_deriv_indices): raise ValueError( f"The passed output indices {self._out.indices} include some that are " f"not in of ({self.of.indices}) or wrt ({self.wrt.indices})" ) # Find indices common to both 'of' and 'wrt' — these will be reduced (summed over) unreduced_indices = [] pre_reduce = [] if self.of.indices is not None: for i, index in enumerate(self.of.indices): index_not_in_out = self._out.indices is not None and index not in self._out.indices index_in_wrt = self.wrt.indices is not None and index in self.wrt.indices if index_not_in_out or index_in_wrt: pre_reduce.append(i) else: unreduced_indices.append(index) object.__setattr__(self, "_pre_reduce", tuple(pre_reduce)) num_after_pre_reduce = len(unreduced_indices) post_reduce = [] if self.wrt.indices is not None: for i, index in enumerate(self.wrt.indices): if self._out.indices is not None and index not in self._out.indices: post_reduce.append(i + num_after_pre_reduce) else: unreduced_indices.append(index) object.__setattr__(self, "_post_reduce", tuple(post_reduce)) # Map output indices to their new position in the output tensor post_permute = ( find_index_permutation(unreduced_indices, self._out.indices) if self._out.indices is not None else [] ) object.__setattr__(self, "_post_permute", tuple(post_permute)) if self._out.indices is None: out_indices = "".join(unreduced_indices) if post_permute: out_indices = "".join(out_indices for i in post_permute) out = GraphEntrySpec(self._out.key_path, out_indices) else: out = self._out object.__setattr__(self, "_actual_out", out)
[docs] @classmethod def create( cls, of: GraphEntrySpecLike, wrt: GraphEntrySpecLike, out: GraphEntrySpecLike | None = None, ) -> "SingleDerivative": of = GraphEntrySpec.create(of) wrt = GraphEntrySpec.create(wrt) out = GraphEntrySpec.create(out) return SingleDerivative(of, wrt, out)
@property def of(self) -> GraphEntrySpec: """Derivative of""" return self._of @property def wrt(self) -> GraphEntrySpec: """Derivative output""" return self._wrt @property def graph_tuple_paths(self) -> "dict[gcnn.typing.TreePath, int]": if isinstance(self._wrt.key_path, int): return {} return {self._wrt.key_path: 0} @property def argnum_paths(self) -> dict[int, int]: if not isinstance(self._wrt.key_path, int): return {} return {self._wrt.key_path: 0} @property def out(self) -> GraphEntrySpec: """Derivative output""" return self._actual_out def __str__(self) -> str: return f"∂{self.of}/∂{self.wrt}->{self.out}" def __truediv__(self, other: "GraphEntrySpecLike | SingleDerivative") -> "MultiDerivative": if isinstance(other, SingleDerivative): return MultiDerivative((self, other)) wrt = GraphEntrySpec.create(other) return MultiDerivative((self, SingleDerivative.create(self.out, wrt)))
[docs] def build_derivative_fn( self, func: DerivableGraphFunction, return_graph: bool, argnum: int ) -> DerivableGraphFunction: if not argnum >= 0: raise ValueError(f"argnum must be >= 0, got: {argnum}") if not self.out.indices: # Scalar valued diff_fn = jax.grad else: # Vector valued # note: we could use forward mode in cases where dim inputs << outputs # but the memory use is often exterme diff_fn = jax.jacrev def _diff_and_pre_process( graph: jraph.GraphsTuple, *args: jt.PyTree ) -> tuple[jt.Array, jraph.GraphsTuple]: value, graph = func(graph, *args) value, graph = self._pre_process(value, graph) return value, graph do_diff = diff_fn(_diff_and_pre_process, argnums=1 + argnum, has_aux=True) def _diff_fn( graph: jraph.GraphsTuple, *args: jt.PyTree ) -> jt.Array | tuple[jt.Array, jraph.GraphsTuple]: if len(args) <= argnum: raise ValueError( f"Derivative needs to be taken wrt argument {argnum}, " f"but only {len(args)} were passed. Did you forget to pass a value for " f"the value at which you would like the derivative to be evaluated?" ) self._check_shape("wrt", self.wrt, args[argnum]) value, graph = do_diff(graph, *args) value, graph = self._post_process(value, graph) if return_graph: return value, graph return value return _diff_fn
def _pre_process( self, value: jt.Array, graph: jraph.GraphsTuple ) -> tuple[jt.Array, jraph.GraphsTuple]: self._check_shape("of", self.of, value) if self._pre_reduce: value = base.as_array(value).sum(axis=self._pre_reduce) return value, graph def _post_process( self, value: jt.Array, graph: jraph.GraphsTuple ) -> tuple[jt.Array, jraph.GraphsTuple]: if self._post_reduce: value = base.as_array(value).sum(axis=self._post_reduce) if self._post_permute: value = value.transpose(self._post_permute) self._check_shape("out", self.out, value) return value, graph @staticmethod def _check_shape(stage: str, spec: GraphEntrySpec, value: jt.Array): if spec.indices is None: return # Nothing to check value = base.as_array(value) if len(spec.indices) != len(value.shape): raise ValueError( f"The passed '{stage}' indices `{spec.indices}` do not match the rank of the " f"value shape: {value.shape}" )
[docs] @dataclasses.dataclass(frozen=True, slots=True) class MultiDerivative(Derivative): parts: tuple[SingleDerivative, ...]
[docs] @classmethod def create( cls, of: GraphEntrySpecLike, wrt: str | Sequence[GraphEntrySpecLike], out: GraphEntrySpecLike | None = None, ) -> "MultiDerivative": if isinstance(wrt, str): wrt = wrt.split(",") if len(wrt) == 1: return MultiDerivative((SingleDerivative.create(of, wrt[0], out),)) # First parts = [SingleDerivative.create(of, wrt[0])] # Middle for part in wrt[1:-1]: parts.append(SingleDerivative.create(parts[-1].out, part)) # Last parts.append(SingleDerivative.create(parts[-1].out, wrt[-1], out)) return MultiDerivative(tuple(parts))
@property def of(self) -> GraphEntrySpec: return self.parts[0].of @property def graph_tuple_paths(self) -> "dict[gcnn.typing.TreePath, int]": paths: "dict[gcnn.typing.TreePath, int]" = {} wrt_map = [] for part in self.parts: if not isinstance(part.wrt.key_path, int): wrt_map.append(paths.setdefault(part.wrt.key_path, len(paths))) return paths @property def argnum_paths(self) -> dict[int, int]: start_idx = len(self.graph_tuple_paths) paths = {} for part in self.parts: if isinstance(part.wrt.key_path, int): paths.setdefault(part.wrt.key_path, start_idx + len(paths)) return paths @property def out(self) -> GraphEntrySpec: return self.parts[-1].out @property def paths(self) -> dict[ArgumentSpecifier, int]: paths = {} wrt_map = [] for part in self.parts: wrt_map.append(paths.setdefault(part.wrt.key_path, len(paths))) return paths def __str__(self) -> str: parts = [f"∂{self.of}/∂{self[0].wrt}"] for deriv in self[1:]: parts.append(f"∂{deriv.wrt}") parts.append(f"->{self.out}") return "".join(parts) def __len__(self) -> int: return len(self.parts) def __getitem__(self, item) -> SingleDerivative | tuple[SingleDerivative, ...]: return self.parts[item] def __iter__(self): yield from self.parts.__iter__()
[docs] def build_derivative_fn( self, func: DerivableGraphFunction, return_graph: bool, argnum: int ) -> DerivableGraphFunction: # Work our way from right to left creating the derivative evaluators argnums = [] for part in self: wrt_path = part.wrt.key_path if isinstance(wrt_path, int): argnums.append(self.argnum_paths[wrt_path]) else: argnums.append(self.graph_tuple_paths[wrt_path]) func = self[0].build_derivative_fn(func, return_graph=return_graph, argnum=argnums[0]) for part, argnum_ in zip(self[1:], argnums[1:]): func = part.build_derivative_fn(func, return_graph=return_graph, argnum=argnum_) return func
def process_paths(parts: Sequence[SingleDerivative]): paths: dict[gcnn.typing.TreePath, int] = {} wrt_map: list[int] = [] arg_parts = [] for part in parts: if isinstance(part.wrt.key_path, int): arg_parts.append(part.wrt.key_path) else: wrt_map.append(paths.setdefault(part.wrt.key_path, len(paths))) for part in parts: pass return paths @dataclasses.dataclass(frozen=True) class Evaluator: func: DerivableGraphFunction spec: Derivative return_graph: bool argnum: int scale: float = 1.0 at: flax.core.FrozenDict[str, jt.PyTree] | None = dataclasses.field(default=None, hash=False) # will be set in __post_init__ _evaluate_at: DerivableGraphFunction = dataclasses.field(init=False) def __post_init__(self): object.__setattr__( self, "_evaluate_at", self.spec.build_derivative_fn(self.func, return_graph=True, argnum=self.argnum), ) def __call__( self, graph: jraph.GraphsTuple, *args, **kwargs: dict[str, jt.PyTree] ) -> jt.Array | tuple[jt.Array, jraph.GraphsTuple]: values: dict[int, jt.PyTree] = {} for name, value in kwargs.items(): idx: int = self.spec.graph_tuple_paths[_tree.path_from_str(name)] values[idx] = value for name, value in (self.at or {}).items(): tree_path = _tree.path_from_str(name) idx: int = self.spec.graph_tuple_paths[tree_path] values[idx] = value diff_args = tuple(value for _, value in sorted(values.items())) + args value, graph_out = self._evaluate_at(graph, *diff_args) if self.spec.out.indices is not None and not len(value.shape) == len(self.spec.out.indices): raise ValueError( f"The output array rank ({len(value.shape)}) does not match the " f"passed number of output indices '{self.spec.out.indices}' " f"({len(self.spec.out.indices)})" ) value = self.scale * value if self.return_graph: return value, graph_out return value
[docs] def diff( *func_of, wrt: GraphEntrySpecLike | Sequence[GraphEntrySpecLike], of: GraphEntrySpecLike | None = None, out: GraphEntrySpecLike = None, scale: float = 1.0, at: dict | None = None, return_graph=False, ) -> Evaluator: """ Constructs a JAX-compatible evaluator for computing single or multiple derivatives of a scalar function (e.g., energy) defined over a Graph. This function acts as a factory, routing the request to create either a SingleDerivative or MultiDerivative object based on the `wrt` argument. The key feature of this function is the use of string-based tensor index notation (e.g., 'nodes.positions:Iγ') for specifying differentiation targets and output shape. Args: *func_of: Either (func), where 'func' is the energy function, or (func, of), where 'of' is the scalar entry (e.g., 'globals.energy') to differentiate. - func (Callable): The function (Graph -> Graph) whose output is differentiated. - of (GraphEntrySpecLike, optional): Specifies the scalar entry within the output graph of 'func' to differentiate. Defaults to the sole scalar output if omitted. wrt (GraphEntrySpecLike | Sequence[GraphEntrySpecLike]): The input entries of the graph with respect to which the derivative is taken. This must be a string or a sequence of strings, specifying the index notation for the input: - Example: 'nodes.positions:Iγ' means differentiate w.r.t. the $\\gamma$ component of the position of node $I$. out (GraphEntrySpecLike, optional): The index notation for the desired output tensor shape. This string defines the contraction of the indices from 'wrt'. - Example: If wrt=['field:α', 'field:β', 'positions:Iγ'], out=':Iγαβ' specifies a rank-4 tensor output with indices $I, \\gamma, \alpha, \beta$. If omitted, the indices are concatenated in the order they appear in 'wrt'. scale (float): A scalar factor to multiply the final derivative result by. Defaults to 1.0. at (dict | None): A dictionary mapping GraphEntrySpecLike strings (without indices) to jax.numpy arrays, specifying the value at which to evaluate the derivative for those entries. These entries will be held constant. - Example: {'globals.electric_field': jnp.zeros(3)} return_graph (bool): If True, the derivative tensor is packaged into a new Graph object under the name specified by the 'out' argument. If False, the function returns the raw derivative tensor. Defaults to False. Returns: Evaluator: A callable object that takes a Graph and returns the computed derivative tensor (or a Graph containing it). Raises: TypeError: If the arguments do not conform to the expected types. Note: The index notation used in 'wrt' and 'out' must adhere to the library's conventions for Graph entry keys and indices. For multi-derivatives, the number of unique indices in 'out' must match the number of indices in 'wrt'. """ if len(func_of) == 1: func = func_of[0] else: if of is not None: raise ValueError( "You need to call diff() as either diff(fn, 'globals.energy', ...) or " "diff(fn, of='globals.energy', ...) but you cannot provide both 2 positional " "arguments and `of`" ) func, of = func_of if isinstance(wrt, str): deriv = SingleDerivative.create(of, wrt, out) elif len(wrt) == 1: deriv = SingleDerivative.create(of, wrt[0], out) else: deriv = MultiDerivative.create(of, wrt, out) return deriv.evaluator(deriv.adapt(func), return_graph, scale=scale, at=at)
def ordered_unique_indices(lst): seen = {} return [i for i, x in enumerate(lst) if x not in seen and not seen.setdefault(x, True)] def find_index_permutation(lst_a, lst_b) -> list[int]: index_map = {value: idx for idx, value in enumerate(lst_a)} return [index_map[x] for x in lst_b]