Source code for tensorial.gcnn.experimental.utils
import copy
from typing import Any
import jraph
# Names exported by ``from <this module> import *``.
__all__ = "update_graph", "GraphMutator"
[docs]
class GraphMutator:
def __init__(self, graph: jraph.GraphsTuple):
self.original = graph
self.mutations: list[tuple[str, tuple[str | int, ...], Any]] = []
def _normalize_path(self, path: str | tuple) -> tuple:
return tuple(path.split(".")) if isinstance(path, str) else path
[docs]
def set(self, path: str | tuple, value) -> "GraphMutator":
self.mutations.append(("set", self._normalize_path(path), value))
return self
[docs]
def update(self, path: str | tuple, updates: dict) -> "GraphMutator":
self.mutations.append(("update", self._normalize_path(path), updates))
return self
[docs]
def delete(self, path: str | tuple) -> "GraphMutator":
self.mutations.append(("delete", self._normalize_path(path), None))
return self
def _apply_mutation(self, container: Any, path: tuple, op: str, value: Any):
if not path:
if op == "set":
return value
if op == "update":
if not isinstance(container, dict):
raise TypeError("Can only apply 'update' to a dict at root level.")
container.update(value)
return container
if op == "delete":
raise ValueError("Cannot delete the root container.")
raise ValueError(f"Unsupported op '{op}' at root.")
*parents, last = path
target = container if container is not None else {}
for key in parents:
target = target[key]
if op == "set":
target[last] = value
elif op == "update":
if not isinstance(target[last], dict):
raise TypeError(f"Cannot update non-dict object at path: {'.'.join(path)}")
target[last].update(value)
elif op == "delete":
del target[last]
return target if container is None else None
[docs]
def get(self) -> jraph.GraphsTuple:
# Deepcopy all mutable fields; assume non-dict fields are safe to mutate directly
mutated_fields = {
# In gcnn these can be (mutable) dicts of immutable values
"nodes": copy.copy(self.original.nodes),
"edges": copy.copy(self.original.edges),
"globals": copy.copy(self.original.globals),
# while the following are simply immutable arrays
"n_node": self.original.n_node,
"n_edge": self.original.n_edge,
"senders": self.original.senders,
"receivers": self.original.receivers,
}
for op, path, value in self.mutations:
root = path[0]
if root not in mutated_fields:
raise ValueError(
f"GraphsTuple does not have an attribute '{root}', "
f"must be one of: {' '.join(mutated_fields.keys())}"
)
result = self._apply_mutation(mutated_fields[root], path[1:], op, value)
if result is not None:
mutated_fields[root] = result
return jraph.GraphsTuple(**mutated_fields)
[docs]
def update_graph(graph: jraph.GraphsTuple) -> GraphMutator:
    """Start a chain of mutations on ``graph``.

    Returns a :class:`GraphMutator` wrapping ``graph``; queue edits on it and
    call its ``get()`` method to obtain the resulting graph.
    """
    mutator = GraphMutator(graph)
    return mutator