Source code for tensorial.gcnn._nodewise

import logging
from typing import TYPE_CHECKING

import e3nn_jax as e3j
from flax import linen
import jax.numpy as jnp
import jraph

from . import _base, _tree, graph_ops, keys, utils
from .. import base
from .experimental import utils as exp_utils

if TYPE_CHECKING:
    import tensorial

__all__ = (
    "NodewiseLinear",
    "NodewiseReduce",
    "NodewiseEmbedding",
    "NodewiseEncoding",
    "NodewiseDecoding",
)

_LOGGER = logging.getLogger(__name__)


[docs] class NodewiseLinear(linen.Module): """Nodewise linear operation""" irreps_out: str | e3j.Irreps irreps_in: e3j.Irreps | None = None field: str = keys.FEATURES out_field: str | None = keys.FEATURES num_types: int | None = None types_field: str | None = None
[docs] def setup(self): # pylint: disable=attribute-defined-outside-init self.linear = e3j.flax.Linear( irreps_out=self.irreps_out, irreps_in=self.irreps_in, num_indexed_weights=self.num_types, force_irreps_out=True, ) # Set the default of the types field if num_types is supplied and the user didn't supply it if self.types_field is None: self._types_field = keys.SPECIES if self.num_types else None else: if not self.num_types: _LOGGER.warning( "User supplied a ``types_field``, %s, but failed to supply ``num_types``. " "Ignoring.", self._types_field, ) self._types_field = self.types_field
@_base.shape_check def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: nodes = graph.nodes if self.num_types: # We are using weights indexed by the type features = self.linear(nodes[self._types_field], nodes[self.field]) else: features = self.linear(nodes[self.field]) nodes[self.out_field] = features return graph._replace(nodes=nodes)
[docs] class NodewiseReduce(linen.Module): """Applies a reduction operation over node features and stores the result in the graph globals. This module reduces a specified field in the graph's node features across all nodes (within each graph if batched) using a specified reduction operation (`sum`, `mean`, or `normalized_sum`). The result is written to the `globals` field of the `GraphsTuple`. Attributes: field (str): Path to the node field to reduce, e.g. "energy" or "features.energy". out_field (Optional[str]): Path to the output global field. If None, defaults to "<reduce>_<field>" under `globals`. reduce (str): Reduction operation to apply. Must be one of "sum", "mean", or "normalized_sum". average_num_atoms (float): Required if `reduce` is "normalized_sum". Used to scale the result by `average_num_atoms ** -0.5`. Raises: ValueError: If `reduce` is not one of the allowed options. ValueError: If `reduce == "normalized_sum"` but `average_num_atoms` is not provided. Returns: A new `GraphsTuple` with the reduced value written to the specified `globals` field. """ field: str out_field: str | None = None reduce: str = "sum" average_num_atoms: float = None as_array: bool = False
[docs] def setup(self): # pylint: disable=attribute-defined-outside-init if self.reduce not in ("sum", "mean", "normalized_sum"): raise ValueError(self.reduce) self._field = ("nodes",) + utils.path_from_str( self.field if self.field is not None else tuple() ) self._out_field = ("globals",) + utils.path_from_str( self.out_field or f"{self.reduce}_{self.field}" ) if self.reduce == "normalized_sum": if self.average_num_atoms is None: raise ValueError(self.average_num_atoms) self.constant = float(self.average_num_atoms) ** -0.5 self._reduce = "sum" else: self.constant = 1.0 self._reduce = self.reduce
@_base.shape_check def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: reduced = self.constant * graph_ops.graph_segment_reduce(graph, self._field, self._reduce) if self.as_array and isinstance(reduced, e3j.IrrepsArray): reduced = reduced.array return exp_utils.update_graph(graph).set(self._out_field, reduced).get()
[docs] class NodewiseEmbedding(linen.Module): """Take the attributes in the nodes dictionary given by attrs, embed them, and store the results as a direct sum of irreps in the out_field. """ attrs: "tensorial.IrrepsTree" out_field: str = keys.ATTRIBUTES node_shape_from: str | None = keys.POSITIONS @_base.shape_check def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: if isinstance(self.attrs, (dict, linen.FrozenDict)): values = [] for key, attr in self.attrs.items(): path = _tree.path_from_str(key) if len(path) > 1: if not path[0] in ("nodes", "globals"): raise ValueError(f"The attribute key must not contain '.', got: {key}") else: # Assume it is a node attribute path = ("nodes",) + path try: value = _tree.get(graph, path) except KeyError: root = path[0] raise ValueError( f"Did not find '{_tree.path_to_str(path[1:])}' in " f"{root}, available entries are: " f"{', '.join(_tree.get(graph, root).keys())}" ) from None value: e3j.IrrepsArray = base.create_tensor(attr, value) if path[0] == "globals": if self.node_shape_from is None: # This will not have a static shape, so will cause recompilation total_repeat_length = jnp.sum(graph.n_node) else: total_repeat_length = graph.nodes[self.node_shape_from].shape[0] if len(value.shape) < 2: value = value.broadcast_to((total_repeat_length, *value.shape)) else: array = jnp.repeat( value.array, graph.n_node, axis=0, total_repeat_length=total_repeat_length, ) value = e3j.IrrepsArray(value.irreps, array) values.append(value) encoded: e3j.IrrepsArray = e3j.concatenate(values) else: values = graph.nodes # Create the embedding encoded: e3j.IrrepsArray = base.create_tensor(self.attrs, values) # Store in output field nodes = graph.nodes nodes[self.out_field] = encoded return graph._replace(nodes=nodes)
[docs] class NodewiseDecoding(linen.Module): """Decode the direct sum of irreps stored in the in_field and store each tensor as a node value with key coming from the attrs. """ attrs: "tensorial.IrrepsTree" in_field: str = keys.ATTRIBUTES @_base.shape_check def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: # Here, we need to split up the direct sum of irreps in the in field, and save the values # in the nodes dict corresponding to the attrs keys idx = 0 nodes_dict = graph.nodes irreps_tensor = nodes_dict[self.in_field] for key, value in base.tensorial_attrs(self.attrs).items(): irreps = base.irreps(value) tensor_slice = irreps_tensor[..., idx : idx + irreps.dim] nodes_dict[key] = base.from_tensor(value, tensor_slice) idx += irreps.dim # All done, return the new graph return graph._replace(nodes=nodes_dict)
# For legacy reasons NodewiseEncoding = NodewiseEmbedding