Source code for tensorial.gcnn._edgewise
from collections.abc import Callable
import functools
from typing import TYPE_CHECKING
import e3nn_jax as e3j
from flax import linen
import jax
import jraph
from . import _base, _spatial, keys
from .. import base
if TYPE_CHECKING:
import tensorial
# Names exported by ``from ... import *``. ``EdgewiseEncoding`` and
# ``RadialBasisEdgeEncoding`` are legacy aliases bound at the bottom of this module.
__all__ = (
    "EdgewiseLinear",
    "EdgewiseDecoding",
    "EdgewiseEmbedding",
    "EdgewiseEncoding",
    "RadialBasisEdgeEmbedding",
    "RadialBasisEdgeEncoding",
    "EdgeVectors",
)
[docs]
class EdgewiseLinear(linen.Module):
    """Edgewise linear operation.

    Passes the edge tensor stored under ``field`` through an ``e3j.flax.Linear`` layer and
    stores the result in the edge dictionary of the returned graph.

    Attributes:
        irreps_out: forwarded to the linear layer as ``irreps_out``.
        irreps_in: forwarded to the linear layer as ``irreps_in`` (``None`` by default).
        field: key in ``graph.edges`` holding the input tensor.
        out_field: key in ``graph.edges`` under which the result is stored.  If ``None``, the
            result is written back to ``field``.
    """

    irreps_out: str | e3j.Irreps
    irreps_in: e3j.Irreps | None = None
    field: str = keys.FEATURES
    out_field: str | None = keys.FEATURES

    def setup(self):
        """Create the linear layer applied in ``__call__``."""
        # pylint: disable=attribute-defined-outside-init
        self.linear = e3j.flax.Linear(
            irreps_out=self.irreps_out,
            irreps_in=self.irreps_in,
            force_irreps_out=True,
        )

    @_base.shape_check
    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Return a copy of ``graph`` whose edges contain the linearly transformed tensor."""
        # ``out_field`` is typed as optional: never use ``None`` itself as a dictionary key
        target = self.field if self.out_field is None else self.out_field
        # Shallow copy so that the edge dictionary of the caller's graph is left untouched
        edges = dict(graph.edges)
        edges[target] = self.linear(edges[self.field])
        return graph._replace(edges=edges)
[docs]
class EdgewiseEmbedding(linen.Module):
    """Build a tensor from edge values and store it on the edges.

    ``base.create_tensor`` is called with ``attrs`` and the graph's edge dictionary; whatever it
    returns is stored under ``out_field`` in the edge dictionary of the returned graph.

    Attributes:
        attrs: description of the edge attributes to encode, passed to ``base.create_tensor``.
        out_field: key in ``graph.edges`` under which the encoded tensor is stored.
    """

    attrs: "tensorial.IrrepsTree"
    out_field: str = keys.ATTRIBUTES

    @_base.shape_check
    def __call__(
        self, graph: jraph.GraphsTuple
    ) -> jraph.GraphsTuple:  # pylint: disable=arguments-differ
        """Return a copy of ``graph`` with the encoded tensor added to its edges."""
        # Create the encoding from the (unmodified) input edges
        encoded = base.create_tensor(self.attrs, graph.edges)
        # Store in output field of a new dictionary so the caller's graph is not mutated
        edges = {**graph.edges, self.out_field: encoded}
        return graph._replace(edges=edges)
[docs]
class EdgewiseDecoding(linen.Module):
    """Decode the direct sum of irreps stored in the in_field and store each tensor as an edge
    value with key coming from the attrs.

    Attributes:
        attrs: description of the tensors to extract; iterated via ``base.tensorial_attrs``.
        in_field: key in ``graph.edges`` holding the tensor to split up.
    """

    attrs: "tensorial.IrrepsTree"
    in_field: str = keys.ATTRIBUTES

    @_base.shape_check
    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Return a copy of ``graph`` with one decoded edge entry per key in ``attrs``."""
        # Work on a shallow copy so the edge dictionary of the caller's graph is not mutated
        edges = dict(graph.edges)
        irreps_tensor = edges[self.in_field]

        # Walk along the last axis of the direct sum, taking one slice of width ``irreps.dim``
        # per attribute and saving the decoded value under the attribute's key
        offset = 0
        for key, value in base.tensorial_attrs(self.attrs).items():
            dim = base.irreps(value).dim
            edges[key] = base.from_tensor(value, irreps_tensor[..., offset : offset + dim])
            offset += dim

        # All done, return the new graph
        return graph._replace(edges=edges)
[docs]
class RadialBasisEdgeEmbedding(linen.Module):
    """Embed edge lengths using the ``e3j.bessel`` radial basis, optionally with an envelope.

    Attributes:
        field: key in the edge dictionary from which the edge lengths are read.
        out_field: key in the edge dictionary under which the embedding is stored.
        num_basis: passed to ``e3j.bessel`` as ``n``.
        r_max: passed to ``e3j.bessel`` as ``x_max``; also used to rescale the lengths before
            they are given to the envelope.
        envelope: falsy for no envelope, a callable to use that callable as the envelope, or
            any other truthy value (e.g. ``True``) to use ``e3j.poly_envelope(1, 1)``.
    """

    field: str = keys.EDGE_LENGTHS
    out_field: str = keys.RADIAL_EMBEDDINGS
    num_basis: int = 8
    r_max: float = 4.0
    envelope: bool | Callable = False

    def setup(self):
        """Create the radial basis function and resolve the envelope setting."""
        # pylint: disable=attribute-defined-outside-init
        self.radial_embedding = functools.partial(
            e3j.bessel,
            x_max=self.r_max,
            n=self.num_basis,
        )
        self._envelope = self._init_envelope(self.envelope)

    @staticmethod
    def _init_envelope(envelope) -> Callable | None:
        """Turn the ``envelope`` setting into a callable, or ``None`` if it is falsy."""
        if not envelope:
            return None
        return envelope if callable(envelope) else e3j.poly_envelope(1, 1)

    @_base.shape_check
    def __call__(
        self, graph: jraph.GraphsTuple
    ) -> jraph.GraphsTuple:  # pylint: disable=arguments-differ
        """Return a copy of ``graph`` with the radial embedding added to its edges."""
        # Shallow copy: ``with_edge_vectors`` may hand back the caller's own dictionary
        edge_dict = dict(_spatial.with_edge_vectors(graph).edges)
        # Read the lengths from the configured field (previously ``keys.EDGE_LENGTHS`` was
        # hard-coded and ``field`` was ignored); the first column is taken as the length
        r = base.as_array(edge_dict[self.field])[:, 0]
        embedded = self.radial_embedding(r)
        if self._envelope is not None:
            # The envelope is evaluated per edge on the length rescaled by r_max and then
            # broadcast over the trailing (basis) axis of the embedding
            embedded = jax.vmap(self._envelope)(r / self.r_max)[..., None] * embedded

        edge_dict[self.out_field] = embedded
        return graph._replace(edges=edge_dict)
[docs]
class EdgeVectors(linen.Module):
    """Module that adds edge vectors, computed from atomic positions, to a graph.  The unit
    cell is taken into account (if present).

    Attributes:
        as_irreps_arrays: forwarded to ``_spatial.with_edge_vectors`` as ``as_irreps_array``.
    """

    as_irreps_arrays: bool = False

    @linen.compact
    @_base.shape_check
    def __call__(
        self, graph: jraph.GraphsTuple
    ) -> jraph.GraphsTuple:  # pylint: disable=arguments-differ
        """Delegate to ``_spatial.with_edge_vectors``, always requesting edge lengths."""
        options = {"with_lengths": True, "as_irreps_array": self.as_irreps_arrays}
        with_vectors = _spatial.with_edge_vectors(graph, **options)
        return with_vectors
# For legacy reasons: keep the old ``*Encoding`` names importable as aliases of the
# corresponding ``*Embedding`` classes (both names are listed in ``__all__``)
EdgewiseEncoding = EdgewiseEmbedding
RadialBasisEdgeEncoding = RadialBasisEdgeEmbedding