Source code for tensorial.gcnn._base
from collections.abc import Sequence
import functools
import logging
from typing import TYPE_CHECKING, Protocol
from jax import tree_util
import jaxtyping as jt
import jraph
from . import _tree
from .experimental import utils as exp_utils
from .typing import GraphFunction
if TYPE_CHECKING:
from tensorial import gcnn
__all__ = "GraphFunction", "shape_check", "adapt", "transform_fn"
_LOGGER = logging.getLogger(__name__)
[docs]
def shape_check(func: "gcnn.typing.GraphFunction") -> "gcnn.typing.GraphFunction":
"""Decorator that will print to the logger any differences in either the keys present in
the graph before and after the call, or any differences in their shapes.
This is super useful for diagnosing jax re-compilation issues.
"""
@functools.wraps(func)
def shape_checker(*args) -> jraph.GraphsTuple:
# Can either be a class method or a free function
inputs: jraph.GraphsTuple = args[0] if len(args) == 1 else args[1]
flattened, _ = tree_util.tree_flatten_with_path(inputs)
in_shapes = {path: array.shape for path, array in flattened}
out = func(*args)
out_shapes = {
(path, array.shape) for path, array in tree_util.tree_flatten_with_path(out)[0]
}
diff = out_shapes - set(in_shapes.items())
messages: list[str] = []
for path, shape in diff:
path_str = _tree.path_to_str(tuple(map(_tree.key_to_str, path)))
try:
in_shape = in_shapes[path]
except KeyError:
messages.append(f"new {path_str}")
else:
messages.append(f"{path_str} {in_shape}->{shape}")
if messages:
_LOGGER.debug(
"%s() difference(s) in inputs/outputs: %s",
func.__qualname__,
", ".join(messages),
)
return out
return shape_checker
class TransformedGraphFunction(Protocol):
"""Transformed graph function that returns a value or a tuple of a value and a graph"""
def __call__(
self, graph: jraph.GraphsTuple, *args: jt.PyTree, **kwargs: jt.PyTree
) -> jt.PyTree | tuple[jt.PyTree, jraph.GraphsTuple]: ...
[docs]
def adapt(
fn: "gcnn.typing.ExGraphFunction",
*args: "gcnn.TreePathLike",
outs: "Sequence[gcnn.TreePathLike]" = tuple(),
return_graphs: bool = False,
**keywords: "gcnn.TreePathLike",
) -> TransformedGraphFunction:
"""Given a graph function, this will return a function that takes a graph as the first argument
followed by positional arguments that will be mapped to the fields given by ``ins``.
Output paths can optionally be specified with ``outs`` which, if supplied, will make the
function return one or more values from the graph as returned by ``fn``.
Args:
fn: the graph function
*args: the input paths
outs: the output paths
return_graphs: if `True` and ``outs`` is specified, this will
return a tuple containing the output graph followed by the
values at ``outs``
Returns:
a function that wraps ``fn`` with the above properties
"""
args = tuple(_tree.path_from_str(path) for path in args)
kwarg_paths = {name: _tree.path_to_str(path) for name, path in keywords.items()}
outs = tuple(_tree.path_from_str(path) for path in outs)
def _fn(
graph: jraph.GraphsTuple, *arg_values: jt.PyTree, **kwarg_values
) -> jt.PyTree | tuple[jt.PyTree, jraph.GraphsTuple]:
# Set the values from graph at the correct paths in the graphs tuple
updater = exp_utils.update_graph(graph)
# Update from positional arguments first
for path, arg in zip(args, arg_values):
updater.set(path, arg)
# Now from kwargs
for name, value in kwarg_values.items():
updater.set(kwarg_paths[name], value)
graph = updater.get()
# Pass the graph through the original function
res = fn(graph, *arg_values[len(args) :])
if outs:
# Extract the quantity that we want as outputs
out_graph: jraph.GraphsTuple = res
out_vals = _tree.get(out_graph, *outs)
if return_graphs:
return out_vals, out_graph
return out_vals
if return_graphs:
# Just return the original input graph
return res, graph
return res
return _fn
transform_fn = adapt # For backward compatibility