Source code for tensorial.gcnn._tree
from collections.abc import Sequence
import functools
from typing import TYPE_CHECKING, Final
import jax
import jraph
from pytray import tree
from . import keys
if TYPE_CHECKING:
from tensorial import gcnn
# Separator used between path components when a tree path is written as a string,
# e.g. with this default "a.b" is split into ("a", "b") and joined back the same way.
DEFAULT_DELIMITER: Final[str] = "."
@functools.singledispatch
def key_to_str(key) -> str:
    """Convert a single JAX pytree path key into a string.

    This is the fallback implementation; the supported key types are handled by the
    overloads registered below via ``key_to_str.register``.

    Raises:
        ValueError: for any key type without a registered overload (the offending key
            is passed as the exception argument).
    """
    raise ValueError(key)


@key_to_str.register(jax.tree_util.GetAttrKey)
def attr_key_to_str(key: jax.tree_util.GetAttrKey) -> str:
    """Attribute-access keys are represented by the attribute name."""
    attribute_name = key.name
    return attribute_name


@key_to_str.register(jax.tree_util.DictKey)
def dict_key_to_str(key: jax.tree_util.DictKey) -> str:
    """Dictionary keys are represented by ``str`` of the underlying dict key."""
    dict_key = key.key
    return str(dict_key)


@key_to_str.register(jax.tree_util.SequenceKey)
def sequence_key_to_str(key: jax.tree_util.SequenceKey) -> str:
    """Sequence keys are represented by their position index as a string."""
    position = key.idx
    return str(position)


@key_to_str.register(jax.tree_util.FlattenedIndexKey)
def indexed_key_to_str(key: jax.tree_util.FlattenedIndexKey) -> str:
    """Flattened-index keys are represented by their index as a string."""
    flat_index = key.key
    return str(flat_index)
[docs]
def path_from_str(
    path_str: "gcnn.typing.TreePathLike", delimiter: str = DEFAULT_DELIMITER
) -> "gcnn.typing.TreePath":
    """Split up a path string into a tuple of path components.

    Args:
        path_str: either a ``delimiter``-separated path string (e.g. ``"a.b"``) or an
            already-split path. A tuple is returned as-is and a list of components is
            converted to a tuple.
        delimiter: the separator between components of a string path.

    Returns:
        the path as a tuple of components. The empty string gives the empty path ``()``.

    Raises:
        TypeError: if ``path_str`` is neither a string, a tuple nor a list.
    """
    if isinstance(path_str, tuple):
        return path_str
    if isinstance(path_str, str):
        # "".split(delimiter) would give ("",), i.e. a path with one empty key, so the
        # empty string is special-cased to mean the empty path.
        if path_str == "":
            return tuple()
        return tuple(path_str.split(delimiter))
    if isinstance(path_str, list):
        # A list of components is an already-split path; normalise it to a tuple.
        return tuple(path_str)
    raise TypeError(
        f"path must be a str or a tuple/list of path components, got {type(path_str).__name__}"
    )
[docs]
def path_to_str(path: "gcnn.typing.TreePathLike", delimiter: str = DEFAULT_DELIMITER) -> str:
    """Return a string representation of a tree path.

    Args:
        path: either a path string, which is returned unchanged, or a sequence of path
            components.
        delimiter: the separator placed between components.

    Returns:
        the components joined by ``delimiter``. Non-string components (e.g. integer
        indices) are converted with ``str`` so they can be joined. The empty path gives
        the empty string, which ``path_from_str`` maps back to ``()``.
    """
    if isinstance(path, str):
        return path
    # str.join rejects non-str items, so convert each component explicitly.
    return delimiter.join(map(str, path))
def get(
    graph: jraph.GraphsTuple, *path: "gcnn.typing.TreePathLike"
) -> jax.Array | tuple[jax.Array, ...]:
    """Extract the value(s) stored in a graph at the given path(s) and return them directly.

    Args:
        graph: the graph to read values from
        *path: one or more paths into the graph's fields, each either a delimited
            string or a tuple of components

    Returns:
        the value at the path when exactly one path is given, otherwise a tuple of the
        values in the same order as the paths
    """
    # Parse every path up front, before touching the graph.
    parsed_paths = [path_from_str(entry) for entry in path]
    fields = graph._asdict()
    values = tuple(tree.get_by_path(fields, parsed) for parsed in parsed_paths)
    if len(parsed_paths) != 1:
        return values
    return values[0]
def to_paths(
wrt: str | Sequence["gcnn.typing.TreePathLike"] | None,
) -> "tuple[gcnn.typing.TreePath, ...]":
if wrt is None:
return tuple()
if isinstance(wrt, str):
return (path_from_str(wrt),)
if isinstance(wrt, Sequence):
return tuple(map(path_from_str, wrt))
raise ValueError(f"wrt must be str or list or tuple thereof, got {type(wrt).__name__}")
def path_root(
    path: "gcnn.typing.TreePathLike",
    delimiter: str = DEFAULT_DELIMITER,
) -> "gcnn.typing.TreePath":
    """Return the root (first component) of a tree path as a one-element path.

    Args:
        path: a delimited path string or a tuple of components.
        delimiter: the separator used if ``path`` is a string.

    Returns:
        a tuple holding only the first component, or ``()`` if the path is empty.
    """
    components = path_from_str(path, delimiter=delimiter)
    if not components:
        return ()
    return (components[0],)
def get_mask(
    graph: jraph.GraphsTuple,
    path: "gcnn.typing.TreePathLike",
    delimiter: str = DEFAULT_DELIMITER,
) -> jax.Array | None:
    """Look up the mask stored alongside the root field of ``path``.

    Only the first component of ``path`` is used. The mask is read from
    ``(root, keys.MASK)`` in the graph's fields; e.g. for ``"nodes.x"`` the lookup path
    is ``("nodes", keys.MASK)``.

    Args:
        graph: the graph to read the mask from.
        path: a path into the graph's fields.
        delimiter: the separator used if ``path`` is a string.

    Returns:
        the mask, or ``None`` if the root field has no mask entry or cannot hold one.
    """
    mask_path = path_root(path, delimiter) + (keys.MASK,)
    fields = graph._asdict()
    try:
        return tree.get_by_path(fields, mask_path)
    except (KeyError, TypeError):
        # KeyError: the root field or its mask entry does not exist.
        # TypeError: the root field cannot be indexed by key at all (e.g. an optional
        # GraphsTuple field that is None), so there is no mask to return either.
        return None