from collections.abc import Callable, Mapping
import functools
import beartype
import e3nn_jax as e3j
from flax import linen
import jaxtyping as jt
from jaxtyping import Bool, Float, Int
import jraph
from tensorial.typing import Array, IndexArray, IntoIrreps, IrrepsArrayShape
from . import _base, _message_passing, keys
from .. import nn_utils
from .. import utils as tensorial_utils
__all__ = ("NequipLayer",)
# Default activations used by gate
DEFAULT_ACTIVATIONS = linen.FrozenDict({"e": "silu", "o": "tanh"})
ActivationLike = str | nn_utils.ActivationFunction
class InteractionBlock(linen.Module):
"""NequIP style interaction block.
Implementation based on
https://github.com/mir-group/nequip/blob/main/nequip/nn/_interaction_block.py
and
https://github.com/mariogeiger/nequip-jax/blob/main/nequip_jax/nequip.py
Args:
irreps_out: the irreps of the output node features
radial_num_layers: the number of layers in the radial MLP
radial_num_neurons: the number of neurons per layer in the
radial MLP
radial_activation: activation function used by radial MLP
avg_num_neighbours: average number of neighbours of each node,
used for normalization
skip_connection: If True, skip connection will be applied at end
of interaction
"""
irreps_out: IntoIrreps = 4 * e3j.Irreps("0e + 1o + 2e")
# Radial
radial_num_layers: int = 1
radial_num_neurons: int = 8
radial_activation: ActivationLike = "swish"
avg_num_neighbours: float | dict[int, float] = 1.0
skip_connection: bool = True
activations: str | Mapping[str, ActivationLike] = DEFAULT_ACTIVATIONS
num_species: int = 1
def setup(self):
# pylint: disable=attribute-defined-outside-init
self._message_passing = _message_passing.MessagePassingConvolution(
irreps_out=self.irreps_out,
avg_num_neighbours=self.avg_num_neighbours,
radial_num_layers=self.radial_num_layers,
radial_num_neurons=self.radial_num_neurons,
radial_activation=self.radial_activation,
)
self._gate = functools.partial(
e3j.gate,
even_act=nn_utils.get_jaxnn_activation(self.activations["e"]),
odd_act=nn_utils.get_jaxnn_activation(self.activations["o"]),
even_gate_act=nn_utils.get_jaxnn_activation(self.activations["e"]),
odd_gate_act=nn_utils.get_jaxnn_activation(self.activations["o"]),
)
self._radial_act = nn_utils.get_jaxnn_activation(self.radial_activation)
@linen.compact
@jt.jaxtyped(typechecker=beartype.beartype)
def __call__(
self,
node_features: IrrepsArrayShape["n_node irreps"],
edge_features: IrrepsArrayShape["n_edge edge_irreps"],
radial_embedding: Float[Array, "n_edge radial_embedding_dim"],
senders: IndexArray["n_edge"],
receivers: IndexArray["n_edge"],
node_species: Int[Array, "n_node"] | None = None,
*,
node_mask: Bool[Array, "n_node"] | None = None,
edge_mask: Bool[Array, "n_edge"] | None = None,
) -> e3j.IrrepsArray:
"""A NequIP interaction made up of the following steps:
- linear on nodes
- tensor product + aggregate
- divide by sqrt(average number of neighbors)
- concatenate
- linear on nodes
- gate non-linearity
"""
# The irreps to use for the output node features
output_irreps = e3j.Irreps(self.irreps_out).regroup()
if node_mask is not None:
node_mask = nn_utils.prepare_mask(node_mask, node_features)
if node_mask is not None:
node_features = e3j.where(
node_mask, node_features, tensorial_utils.zeros_like(node_features)
)
node_feats = e3j.flax.Linear(node_features.irreps, name="linear_up")(node_features)
if node_mask is not None:
node_features = e3j.where(
node_mask, node_features, tensorial_utils.zeros_like(node_features)
)
node_feats = self._message_passing(
node_feats, edge_features, radial_embedding, senders, receivers, edge_mask=edge_mask
)
gate_irreps = output_irreps.filter(keep=node_feats.irreps)
num_non_scalar = gate_irreps.filter(drop="0e + 0o").num_irreps
gate_irreps = gate_irreps + (num_non_scalar * e3j.Irrep("0e"))
# Second linear, now we create any extra gate scalars
node_feats = e3j.flax.Linear(gate_irreps, name="linear_down")(node_feats)
# self-connection: species weighted tensor product that maps to current irreps space
if self.skip_connection:
skip = e3j.flax.Linear(
node_feats.irreps,
num_indexed_weights=self.num_species,
name="skip_connection",
force_irreps_out=True,
)(node_species, node_features)
node_feats = 0.5 * (node_feats + skip)
# Apply non-linearity
node_feats = self._gate(node_feats)
return node_feats
[docs]
class NequipLayer(linen.Module):
"""NequIP convolution layer.
Implementation based on:
https://github.com/mir-group/nequip/blob/main/nequip/nn/_convnetlayer.py
"""
irreps_out: IntoIrreps
invariant_layers: int = 1
invariant_neurons: int = 8
# Radial
radial_num_layers: int = 1
radial_num_neurons: int = 8
radial_activation: ActivationLike = "swish"
avg_num_neighbours: float | dict[int, float] = 1.0
activations: str | Mapping[str, ActivationLike] = DEFAULT_ACTIVATIONS
node_features_field = keys.FEATURES
skip_connection: bool = True
num_species: int = 1
interaction_block: Callable = None
resnet: bool = False
[docs]
def setup(self):
# pylint: disable=attribute-defined-outside-init
if self.interaction_block is None:
self._interaction_block = InteractionBlock(
self.irreps_out,
radial_num_layers=self.radial_num_layers,
radial_num_neurons=self.radial_num_neurons,
radial_activation=self.radial_activation,
avg_num_neighbours=self.avg_num_neighbours,
skip_connection=self.skip_connection,
activations=self.activations,
num_species=self.num_species,
)
else:
self._interaction_block = self.interaction_block
@linen.compact
@jt.jaxtyped(typechecker=beartype.beartype)
@_base.shape_check
def __call__(
self, graph: jraph.GraphsTuple
) -> jraph.GraphsTuple: # pylint: disable=arguments-differ
"""Apply a standard NequIP layer followed by an optional resnet step
Args:
graph: the input graph
Returns:
the output graph with node features updated
"""
node_features = self._interaction_block(
graph.nodes[keys.FEATURES],
graph.edges[keys.ATTRIBUTES],
graph.edges[keys.RADIAL_EMBEDDINGS],
graph.senders,
graph.receivers,
graph.nodes.get(keys.SPECIES),
node_mask=graph.nodes.get(keys.MASK),
edge_mask=graph.edges.get(keys.MASK),
)
# If enabled, perform ResNet operation by adding back the old node features
if self.resnet:
node_features = node_features + graph.nodes[self.node_features_field]
# Update the graph
nodes = dict(graph.nodes)
nodes[keys.FEATURES] = node_features
return graph._replace(nodes=nodes)