from collections.abc import Callable, Sequence
import functools
from typing import TYPE_CHECKING, Any
import beartype
from flax import linen
import jax
import jax.numpy as jnp
import jaxtyping as jt
import jraph
from pytray import tree
from . import _base, _tree
from .. import base
if TYPE_CHECKING:
import tensorial
from tensorial import gcnn
__all__ = ("grad", "jacobian", "jacrev", "jacfwd", "hessian", "Grad", "Jacobian", "Jacfwd")
TreePath = tuple[Any, ...]
GradOut = jraph.GraphsTuple | jt.PyTree | tuple[jt.PyTree]
@jt.jaxtyped(typechecker=beartype.beartype)
def grad_shim(
fn: "gcnn.typing.GraphFunction",
graph: jraph.GraphsTuple,
of: tuple,
paths: tuple["gcnn.typing.TreePathLike"],
*wrt_variables,
) -> tuple[jax.Array, jraph.GraphsTuple]:
def repl(path, val):
try:
idx = paths.index(tuple(map(_tree.key_to_str, path)))
return wrt_variables[idx]
except ValueError:
return val
graph = jax.tree_util.tree_map_with_path(repl, graph)
# Pass the graph through the original function
out_graph = fn(graph)
# Extract the quantity that we want to differentiate
return jnp.sum(base.as_array(tree.get_by_path(out_graph._asdict(), of))), out_graph
def _create_grad_shim(
fn: "gcnn.typing.GraphFunction",
of: "gcnn.TreePathLike",
wrt: "Sequence[gcnn.typing.TreePathLike]",
sum_axis: bool | int | None = None,
) -> "Callable[[jraph.GraphsTuple, ...], tuple[tensorial.typing.ArrayType, jraph.GraphsTuple]]":
"""Create a function that takes the values of the quantities we want to take the derivatives
with respect to
"""
if of is not None and len(of) < 2:
raise ValueError(f"of must be of at least length two e.g. ('globals', 'entry'), got: {of}")
def shim(
graph: jraph.GraphsTuple, *args
) -> "tuple[tensorial.typing.ArrayType, jraph.GraphsTuple]":
new_fn = _base.transform_fn(fn, *wrt, outs=[of], return_graphs=True)
# Pass the graph through the function
value, graph_out = new_fn(graph, *args)
value = base.as_array(value)
if sum_axis is not False:
value = value.sum(axis=sum_axis)
return value, graph_out
return shim
def _graph_autodiff(
diff_fn: Callable,
func: "gcnn.typing.GraphFunction",
of: "gcnn.typing.TreePathLike",
wrt: "str | Sequence[gcnn.typing.TreePathLike]",
sign: float = 1.0,
sum_axis=None,
has_aux: bool = False,
) -> Callable[[jraph.GraphsTuple], GradOut]:
# Gradient of
of = _tree.path_from_str(of)
# Gradient with respect to
wrt: tuple[gcnn.TreePath, ...] = _tree.to_paths(wrt)
# Creat the shim which will be a function that takes the graph as first argument, and
# the remaining values are the values to take the gradient at
shim = _create_grad_shim(func, of, wrt, sum_axis=sum_axis)
grad_fn = diff_fn(shim, argnums=tuple(range(1, len(wrt) + 1)), has_aux=True)
# Evaluate
def calc_grad(graph: jraph.GraphsTuple, *wrt_values) -> GradOut:
if len(wrt_values) != len(wrt):
raise ValueError(
f"Failed to supply valued to evaluate derivatives at, expected: "
f"{','.join(map(_tree.path_to_str, wrt))}"
)
grads, graph_out = grad_fn(graph, *wrt_values)
grads = [sign * grad for grad in grads]
if len(wrt_values) == 1:
grads = grads[0]
if has_aux:
return grads, graph_out
return grads
return calc_grad
[docs]
def grad(
of: "gcnn.TreePathLike",
wrt: "gcnn.TreePathLike | Sequence[gcnn.TreePathLike]",
sign: float = 1.0,
has_aux: bool = False,
) -> Callable[["gcnn.GraphFunction"], Callable[[jraph.GraphsTuple, ...], GradOut]]:
"""Build a partially initialised Grad function whose only
Args:
kwargs: accepts any arguments that `Grad` does
Returns:
the partially initialized Grad function
"""
return functools.partial(_graph_autodiff, jax.grad, of=of, wrt=wrt, sign=sign, has_aux=has_aux)
[docs]
def jacrev(
of: "gcnn.TreePathLike",
wrt: "gcnn.TreePathLike | Sequence[gcnn.TreePathLike]",
sign: float = 1.0,
has_aux: bool = False,
) -> Callable[["gcnn.GraphFunction"], Callable[[jraph.GraphsTuple, ...], GradOut]]:
"""Build a partially initialised Grad function whose only
Args:
kwargs: accepts any arguments that `Grad` does
Returns:
the partially initialized Grad function
"""
return functools.partial(
_graph_autodiff, jax.jacrev, of=of, wrt=wrt, sign=sign, sum_axis=0, has_aux=has_aux
)
[docs]
def jacfwd(
of: "gcnn.TreePathLike",
wrt: "gcnn.typing.TreePathLike | Sequence[gcnn.typing.TreePathLike]",
sign: float = 1.0,
has_aux: bool = False,
) -> Callable[["gcnn.typing.GraphFunction"], Callable[[jraph.GraphsTuple, ...], GradOut]]:
"""Build a partially initialised Grad function whose only
Args:
kwargs: accepts any arguments that `Grad` does
Returns:
the partially initialized Grad function
"""
return functools.partial(
_graph_autodiff, jax.jacfwd, of=of, wrt=wrt, sign=sign, sum_axis=0, has_aux=has_aux
)
jacobian = jacrev
[docs]
def hessian(
of: "gcnn.TreePathLike",
wrt: "gcnn.TreePathLike | Sequence[gcnn.TreePathLike]",
sign: float = 1.0,
has_aux: bool = False,
) -> Callable[["gcnn.GraphFunction"], Callable[[jraph.GraphsTuple, ...], GradOut]]:
"""Build a partially initialised Grad function whose only
Args:
kwargs: accepts any arguments that `Grad` does
Returns:
the partially initialized Grad function
"""
return functools.partial(
_graph_autodiff, jax.hessian, of=of, wrt=wrt, sign=sign, sum_axis=None, has_aux=has_aux
)
[docs]
class Grad(linen.Module):
"""
The `Grad` class computes gradients of graph-based functions with respect to specified graph
attributes. It enables automatic differentiation of operations defined on graph structures,
such as computing how changes in node positions affect edge lengths or other graph properties.
The class supports both scalar and vector-valued gradients and integrates with JAX for efficient
computation.
"""
func: "gcnn.typing.GraphFunction"
of: "gcnn.typing.TreePathLike" # Gradient of
wrt: "gcnn.TreePathLike | list[gcnn.TreePathLike]" # Gradient with respect to
out_field: "str | gcnn.TreePathLike | list[gcnn.TreePathLike]" = "auto"
sign: float = 1.0
[docs]
def setup(self):
# pylint: disable=attribute-defined-outside-init
self._of = _tree.to_paths(self.of)[0]
self._wrt = _tree.to_paths(self.wrt)
self._out_field = _out_derivative_keys(self._of, self._wrt, self.out_field)
self._grad_fn = grad(self._of, self._wrt, sign=self.sign, has_aux=True)(self.func)
@_base.shape_check
def __call__(self, graph: jraph.GraphsTuple) -> GradOut:
wrt = _tree.get(graph, *self._wrt)
if len(self._wrt) == 1:
wrt = [wrt]
res, out_graph = self._grad_fn(graph, *wrt)
if len(self._wrt) == 1:
res = [res]
graph_updates = out_graph._asdict()
for path, value in zip(self._out_field, res):
tree.set_by_path(graph_updates, path, value)
return jraph.GraphsTuple(**graph_updates)
[docs]
class Jacobian(linen.Module):
func: "gcnn.typing.GraphFunction"
of: "gcnn.typing.TreePathLike"
wrt: str | Sequence["gcnn.typing.TreePathLike"]
out_field: str | Sequence[str] = "auto"
sign: float = 1.0
[docs]
def setup(self):
# pylint: disable=attribute-defined-outside-init
self._of = _tree.to_paths(self.of)[0]
self._wrt = _tree.to_paths(self.wrt)
self._out_field = _out_derivative_keys(self._of, self._wrt, self.out_field)
self._grad_fn = jacobian(self.of, self.wrt, self.sign)(self.func)
@_base.shape_check
def __call__(self, graph: jraph.GraphsTuple) -> GradOut:
wrt = _tree.get(graph, *self._wrt)
if len(self._wrt) == 1:
wrt = [wrt]
res = self._grad_fn(graph, *wrt)
graph_updates = graph._asdict()
for path, value in zip(self._out_field, res):
tree.set_by_path(graph_updates, path, value)
return jraph.GraphsTuple(**graph_updates)
[docs]
class Jacfwd(linen.Module):
func: "gcnn.typing.GraphFunction"
of: "gcnn.typing.TreePathLike"
wrt: str | Sequence["gcnn.typing.TreePathLike"]
out_field: str | Sequence[str] = "auto"
sign: float = 1.0
[docs]
def setup(self):
# pylint: disable=attribute-defined-outside-init
self._of = _tree.to_paths(self.of)[0]
self._wrt = _tree.to_paths(self.wrt)
self._out_field = _out_derivative_keys(self._of, self._wrt, self.out_field)
self._grad_fn = jacfwd(self.of, self.wrt, self.sign)(self.func)
@_base.shape_check
def __call__(self, graph: jraph.GraphsTuple) -> GradOut:
wrt = _tree.get(graph, *self._wrt)
if len(self._wrt) == 1:
wrt = [wrt]
res = self._grad_fn(graph, *wrt)
graph_updates = graph._asdict()
for path, value in zip(self._out_field, res):
tree.set_by_path(graph_updates, path, value)
return jraph.GraphsTuple(**graph_updates)
def _out_derivative_keys(
of: "gcnn.TreePath", wrt: "Sequence[gcnn.TreePath]", out_key
) -> "tuple[gcnn.TreePath, ...]":
if out_key == "auto":
derivs = []
for wrt_entry in wrt:
derivs.append(wrt_entry[:-1] + (f"d{'.'.join(of[1:])}/d{wrt_entry[-1]}",))
return tuple(derivs)
if not isinstance(out_key, list):
out_key = [out_key]
return _tree.to_paths(out_key)
def _create(of: "gcnn.TreePath", wrt: Sequence[tuple]) -> "list[gcnn.TreePath]":
derivs = []
for wrt_entry in wrt:
derivs.append(wrt_entry[:-1] + (f"d{'.'.join(of[1:])}/d{wrt_entry[-1]}",))
return derivs