Source code for tensorial.gcnn._modules

from collections.abc import Hashable, Sequence
import logging
from typing import TYPE_CHECKING

import e3nn_jax as e3j
from flax import linen
import jax
import jax.lax
import jax.numpy as jnp
import jraph
from pytray import tree

from . import _base, utils
from .experimental import utils as exp_utils

if TYPE_CHECKING:
    # Only imported for static type checkers. In the code visible here the name is
    # used solely inside a string annotation, so it is not needed at runtime.
    from tensorial import gcnn

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Public names exported by this module.
__all__ = "Rescale", "IndexedLinear", "IndexedRescale"


class Rescale(linen.Module):
    """Apply a constant scale and/or shift to selected fields of a `jraph.GraphsTuple`.

    Each configured field is looked up by its path inside the graph (for example
    under `nodes`, `edges` or `globals`).  Fields listed in `scale_fields` are
    multiplied by `scale`; fields listed in `shift_fields` then have `shift` added.
    All scaling is performed before any shifting, so a field that appears in both
    lists ends up as `value * scale + shift`.

    Typical uses are normalising / de-normalising values or applying a fixed unit
    conversion.

    Examples:
        Shift the energy stored on every node by 12.5::

            Rescale(shift_fields="nodes.energy", shift=12.5)

        Rescale the global volume by 1e-3::

            Rescale(scale_fields=["globals.volume"], scale=1e-3)

    Attributes:
        shift_fields: A single path string (e.g. ``"nodes.energy"``) or a sequence
            of path strings naming the fields to shift.
        scale_fields: A single path string or a sequence of path strings naming
            the fields to scale.
        shift: Constant added to every field in `shift_fields`.  Defaults to 0.0.
        scale: Constant multiplying every field in `scale_fields`.  Defaults to 1.0.

    Notes:
        - A field whose lookup raises `KeyError` is skipped silently.
        - When `shift` is non-zero and a shifted path starts with ``"globals"``,
          an INFO-level message is logged because the shifted field is no longer
          size extensive with the number of nodes/edges.
    """

    shift_fields: str | Sequence[Hashable] = ()
    scale_fields: str | Sequence[Hashable] = ()
    shift: jax.typing.ArrayLike = 0.0
    scale: jax.typing.ArrayLike = 1.0

    def setup(self):
        """Parse the configured field paths and log about shifted global fields."""
        # pylint: disable=attribute-defined-outside-init
        # A bare string denotes exactly one path; wrap it so that it is not
        # iterated character by character below.
        if isinstance(self.shift_fields, str):
            raw_shift_fields = [self.shift_fields]
        else:
            raw_shift_fields = self.shift_fields

        if isinstance(self.scale_fields, str):
            raw_scale_fields = [self.scale_fields]
        else:
            raw_scale_fields = self.scale_fields

        self._shift_fields = tuple(utils.path_from_str(entry) for entry in raw_shift_fields)
        self._scale_fields = tuple(utils.path_from_str(entry) for entry in raw_scale_fields)

        if self.shift != 0.0:
            shifted_globals = (
                candidate for candidate in self._shift_fields if candidate[0] == "globals"
            )
            for global_path in shifted_globals:
                _LOGGER.info(
                    "Setting shift `%s` to `%f`, this field will no longer be size "
                    "extensive with the number of nodes/edges",
                    utils.path_to_str(global_path),
                    self.shift,
                )

    @linen.compact
    @_base.shape_check
    def __call__(
        self, graph: jraph.GraphsTuple
    ) -> jraph.GraphsTuple:  # pylint: disable=arguments-differ
        """Return a new graph with the configured fields scaled and then shifted.

        Args:
            graph: The graph whose fields should be transformed.

        Returns:
            A `jraph.GraphsTuple` built from the updated fields.
        """
        contents = utils.UpdateDict(graph._asdict())

        # Order matters: every scale is applied before any shift.
        passes = (
            (self._scale_fields, lambda current: current * self.scale),
            (self._shift_fields, lambda current: current + self.shift),
        )
        for paths, transform in passes:
            for path in paths:
                try:
                    updated = transform(tree.get_by_path(contents, path))
                    tree.set_by_path(contents, path, updated)
                except KeyError:
                    # Missing fields are deliberately ignored.
                    continue

        return jraph.GraphsTuple(**contents._asdict())
class IndexedRescale(linen.Module):
    """Apply a per-type affine transformation (scale and shift) to a field of a graph.

    Every entry of the input field is transformed as ``output = input * scale + shift``
    where `scale` and `shift` are selected by the integer stored at the same position
    in `index_field` (e.g. an atomic or node type).  The scales and shifts are either
    learnable parameters or constants supplied by the caller.

    This is typically used to normalise or de-normalise features such as node
    energies depending on the type of node or atom.

    Attributes:
        num_types: Number of distinct types (i.e. distinct values in `index_field`).
            Determines the number of learnable `scales` and `shifts` parameters.
        index_field: Path (e.g. ``"nodes.type"``) to the array of indices used to
            select the scale and shift for each input.
        field: Path to the input field to be rescaled.
        out_field: Path to the output field.  If `None`, the result is written back
            to `field`.
        shifts: Optional constant shifts.  Either a scalar (used for every type) or
            an array whose first dimension has length `num_types`.  If `None`, the
            shifts are learned parameters of shape ``(num_types,)`` initialised
            with `shift_init`.
        scales: Optional constant scales.  Either a scalar (used for every type) or
            an array whose first dimension has length `num_types`, e.g. of shape
            ``(num_types, 1)``.  If `None`, the scales are learned parameters of
            shape ``(num_types, 1)`` initialised with `rescale_init`.
        rescale_init: Initializer for the learnable scale parameters.
        shift_init: Initializer for the learnable shift parameters.

    Raises:
        ValueError: If constant `scales` or `shifts` are not scalar and their first
            dimension does not have length `num_types`.

    Notes:
        - `e3nn_jax.IrrepsArray` inputs are supported; the result is wrapped in an
          `IrrepsArray` with the same irreps as the input.
        - The affine map is applied with `jax.vmap` over the leading axis of the
          inputs, the selected scales and the selected shifts.
    """

    num_types: int
    index_field: str
    field: str
    out_field: str | None = None
    shifts: jax.typing.ArrayLike | None = None
    scales: jax.typing.ArrayLike | None = None
    rescale_init: linen.initializers.Initializer = linen.initializers.lecun_normal()
    shift_init: linen.initializers.Initializer = linen.initializers.zeros_init()

    def setup(self):
        """Parse the field paths and create (or validate) the scales and shifts."""
        # pylint: disable=attribute-defined-outside-init
        self._index_field = utils.path_from_str(self.index_field)
        self._field = utils.path_from_str(self.field)
        self._out_field = (
            self._field if self.out_field is None else utils.path_from_str(self.out_field)
        )

        self._scales = (
            self.param(
                "scales",
                self.rescale_init,
                (self.num_types, 1),
            )
            if self.scales is None
            else self._to_array(self.scales, self.num_types, "scales")
        )
        self._shifts = (
            self.param("shifts", self.shift_init, (self.num_types,))
            if self.shifts is None
            else self._to_array(self.shifts, self.num_types, "shifts")
        )

    @_base.shape_check
    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Return a copy of `graph` with the rescaled values stored at `out_field`.

        Args:
            graph: Graph containing the fields referenced by `index_field` and `field`.

        Returns:
            The updated graph.
        """
        graph_dict: dict = graph._asdict()

        # Get the indexes and values
        indexes = tree.get_by_path(graph_dict, self._index_field)
        inputs = tree.get_by_path(graph_dict, self._field)

        # Work on the raw array and remember the irreps so they can be restored below
        if isinstance(inputs, e3j.IrrepsArray):
            output_irreps = inputs.irreps
            inputs = inputs.array
        else:
            output_irreps = None

        # Select one scale and one shift per entry.  No `axis` is passed, so
        # `jnp.take` indexes the flattened arrays; a `(num_types, 1)` table therefore
        # behaves exactly like a `(num_types,)` one.
        scales = jnp.take(self._scales, indexes)
        shifts = jnp.take(self._shifts, indexes)

        outs = jax.vmap(lambda inp, scale, shift: inp * scale + shift, (0, 0, 0))(
            inputs, scales, shifts
        )
        if output_irreps is not None:
            outs = e3j.IrrepsArray(output_irreps, outs)

        return exp_utils.update_graph(graph).set(self._out_field, outs).get()

    @staticmethod
    def _to_array(value, num_types, label="values"):
        """Convert user supplied constants into a per-type lookup array.

        Args:
            value: A scalar, which is repeated once per type, or an array-like whose
                first dimension has length `num_types`.
            num_types: The number of distinct types.
            label: Name of the quantity, used only in the error message.

        Returns:
            A `jax.Array` holding the per-type constants.

        Raises:
            ValueError: If `value` is not scalar and its first dimension does not
                have length `num_types`.
        """
        array = value if isinstance(value, jax.Array) else jnp.asarray(value)
        if array.ndim == 0:
            # A single constant shared by all types
            return jnp.array([value] * num_types)

        # Array-likes are used as given rather than being tiled `num_types` times.
        # A mismatching length would otherwise silently select wrong entries.
        if array.shape[0] != num_types:
            raise ValueError(
                f"Expected `{label}` to have {num_types} entries along its first "
                f"dimension (one per type), got shape {array.shape}"
            )
        return array
class IndexedLinear(linen.Module):
    """Transform a field of a `GraphsTuple` with a linear layer selected per element.

    An `e3nn_jax.flax.Linear` module is created with `num_types` sets of indexed
    weights.  For every element of the input field, the integer found at the same
    position in `index_field` is passed along with the element so that the linear
    layer can select the weight set to use.

    Attributes:
        irreps_out: The output irreducible representations of the linear layer.
        num_types: Number of distinct index values, i.e. the number of weight sets.
        index_field: Dot-separated path to the index array within the graph.
        field: Dot-separated path to the input features within the graph.
        out_field: Dot-separated path where the output is written.  If `None`, the
            output overwrites `field`.
        name: Optional name, also forwarded to the internal `Linear` module.

    Example:
        If ``graph.nodes["feat"]`` holds the input features and
        ``graph.nodes["type"]`` holds integer indices in ``[0, num_types)``::

            IndexedLinear("64x0e", num_types=5, index_field="nodes.type", field="nodes.feat")

        transforms each node's ``"feat"`` according to its ``"type"``.

    Notes:
        - Looking up `field` or `index_field` is not guarded here, so an error raised
          by the lookup propagates to the caller.
        - NOTE(review): no range check on the index values is performed in this
          class; how out-of-range indices behave depends on `e3nn_jax` - confirm.
    """

    irreps_out: str | e3j.Irreps
    num_types: int
    index_field: str
    field: str
    out_field: str | None = None
    name: str = None

    @linen.compact
    @_base.shape_check
    def __call__(
        self, graph: jraph.GraphsTuple
    ) -> jraph.GraphsTuple:  # pylint: disable=arguments-differ
        """Apply the indexed linear layer and return the updated graph.

        Args:
            graph: Graph containing the fields referenced by `index_field` and `field`.

        Returns:
            The graph with the transformed features stored at the output path.
        """
        selector_path = utils.path_from_str(self.index_field)
        source_path: "gcnn.TreePath" = utils.path_from_str(self.field)
        if self.out_field is None:
            # No explicit destination: overwrite the input field
            target_path = source_path
        else:
            target_path = utils.path_from_str(self.out_field)

        per_type_linear = e3j.flax.Linear(
            self.irreps_out,
            num_indexed_weights=self.num_types,
            name=self.name,
            force_irreps_out=True,
        )

        contents = graph._asdict()
        type_ids = tree.get_by_path(contents, selector_path)
        features = tree.get_by_path(contents, source_path)

        # The indices come first, followed by the values to transform
        transformed = per_type_linear(type_ids, features)
        return exp_utils.update_graph(graph).set(target_path, transformed).get()