import logging
import numbers
from typing import Final
import e3nn_jax as e3j
import jax.numpy as jnp
from jaxtyping import Float, Num
import jraph
import numpy as np
import reax.metrics
from tensorial.typing import Array, CellType, PbcType
from . import keys
from .. import base, geometry, nn_utils
_LOGGER = logging.getLogger(__name__)
__all__ = ("graph_from_points", "with_edge_vectors")
[docs]
def graph_from_points(
pos: Float[Array, "n_nodes 3"],
r_max: numbers.Number,
*,
fractional_positions: bool = False,
self_interaction: bool = True,
strict_self_interaction: bool = False,
cell: CellType | None = None,
pbc: bool | PbcType | None = None,
nodes: dict[str, Num[Array, "n_nodes *"]] | None = None,
edges: dict | None = None,
graph_globals: dict[str, Array] | None = None,
np_=np,
) -> jraph.GraphsTuple:
"""Create a jraph Graph from a set of atomic positions and other related data.
Args:
pos: a [N, 3] array of atomic positions
r_max: the cutoff radius to use for identifying neighbours
fractional_positions: if ``True``, `pos` are interpreted as
fractional positions
self_interaction: if ``True``, edges are created between an atom
and itself in other unit cells
strict_self_interaction: if ``True``, edges are created between
an atom and itself within the central unit cell
cell: a [3, 3] array of unit cell vectors (in row-major format)
pbc: a ``bool`` of a sequence of three `bool`s indicating
whether the space is periodic in x, y, z directions
nodes: a dictionary containing additional data relating to each
node, it should contain arrays of shape [N, ...]
graph_globals: a dictionary containing additional global data
Returns:
the corresponding jraph Graph
"""
pos = np_.asarray(pos)
nodes = nodes if nodes else {}
num_nodes = len(pos)
nodes = nodes or {}
for name, value in nodes.items():
if value.shape[0] != num_nodes:
raise ValueError(
f"node attributes should have shape [N, ...], got {value.shape[0]} != {num_nodes} "
f"for {name}"
)
if pbc is None:
# there are no PBC if cell and pbc are not provided
if cell is not None:
raise ValueError("A cell was provided without PBCs")
pbc = False
if isinstance(pbc, bool):
pbc = (pbc,) * 3
if len(pbc) != 3:
raise ValueError(f"PBC must have length 3: {pbc}")
if fractional_positions:
if cell is None:
raise ValueError("Unit cell must be provided if fractional_positions is True")
# Use PBC to mask and only transform periodic dimensions from fractional (assuming
# non-periodic coordinates are already Cartesian)
pos[:, pbc] = (pos @ cell)[pbc]
neighbour_finder = geometry.neighbour_finder(
r_max,
cell,
pbc=pbc,
include_self=strict_self_interaction,
include_images=self_interaction,
)
neighbour_list = neighbour_finder.get_neighbours(pos)
from_idx, to_idx, cell_shifts = tuple(map(np_.array, neighbour_list.get_edges()))
nodes[keys.POSITIONS] = pos
graph_globals = graph_globals or {}
if pbc is not None:
graph_globals[keys.PBC] = np_.array(pbc, dtype=bool)
if cell is not None:
graph_globals[keys.CELL] = cell
# We have to pad out the globals to make things like batching work
graph_globals = {
key: np_.expand_dims(base.atleast_1d(value), 0)
for key, value in graph_globals.items()
if value is not None
}
edges = edges or {}
edges = {key: value[from_idx, to_idx] for key, value in edges.items()}
# Make sure the edge arrays have array-like (rather than scalar) entries for each edge
edges = {
key: np_.expand_dims(value, -1) if value.ndim == 1 else value
for key, value in edges.items()
}
if cell is not None:
edges[keys.EDGE_CELL_SHIFTS] = cell_shifts
return jraph.GraphsTuple(
nodes=nodes,
edges=edges,
senders=from_idx,
receivers=to_idx,
globals=graph_globals,
n_node=np_.array([len(pos)]),
n_edge=np_.array([len(from_idx)]),
)
[docs]
def with_edge_vectors(
graph: jraph.GraphsTuple,
with_lengths: bool = True,
as_irreps_array: bool | None = True,
) -> jraph.GraphsTuple:
"""Compute edge displacements for edge vectors in a graph.
This will add edge attributes corresponding that cache the vectors and displacements, meaning
that they will not be recalculated if already done so.
"""
edges = graph.edges
pos = graph.nodes[keys.POSITIONS]
edge_vecs = pos[graph.receivers] - pos[graph.senders]
if keys.CELL in graph.globals:
cell = graph.globals[keys.CELL]
cell_shifts = edges[keys.EDGE_CELL_SHIFTS]
shift_vectors = jnp.einsum(
"ni,nij->nj",
cell_shifts,
jnp.repeat(cell, graph.n_edge, axis=0, total_repeat_length=edge_vecs.shape[0]),
)
edge_vecs = edge_vecs + shift_vectors
edge_mask = graph.edges.get(keys.MASK)
if edge_mask is not None:
edge_mask = reax.metrics.utils.prepare_mask(edge_vecs, edge_mask)
edge_vecs = jnp.where(edge_mask, edge_vecs, 1.0)
if not isinstance(edge_vecs, e3j.IrrepsArray) and as_irreps_array:
edge_vecs = e3j.IrrepsArray("1o", edge_vecs)
edges[keys.EDGE_VECTORS] = edge_vecs
# To allow grad to work, we need to mask off the padded edge-vectors that are zero, see:
# * https://github.com/google/jax/issues/6484,
# * https://stackoverflow.com/q/74864427/1257417
if with_lengths:
lengths = jnp.expand_dims(jnp.linalg.norm(base.as_array(edge_vecs), axis=-1), -1)
if edge_mask is not None:
lengths = jnp.where(edge_mask, lengths, 0.0)
if as_irreps_array:
lengths = e3j.as_irreps_array(lengths)
edges[keys.EDGE_LENGTHS] = lengths
graph = graph._replace(edges=edges)
return graph
def _pairwise_sq_distances(pos: Array, mask: Array | None) -> Array:
# x_shape: (num_nodes_in_graph, 3)
if mask is not None:
mask = nn_utils.prepare_mask(mask, pos)
pos = jnp.where(mask, pos, jnp.inf)
diff = pos[:, None, :] - pos[None, :, :]
return (diff**2).sum(axis=2)
def _update_positions(
padded_graph: jraph.GraphsTuple, new_pos: Array, r_max: float
) -> jraph.GraphsTuple:
nodes_dict = padded_graph.nodes
if not nodes_dict[keys.POSITIONS].shape == new_pos.shape:
raise ValueError(
f"New positions ({new_pos.shape}) and current position arrays "
f"({nodes_dict[keys.POSITIONS].shape}) do not have the same shape"
)
nodes_dict["positions"] = new_pos
dists_mtx = _pairwise_sq_distances(new_pos, nodes_dict[keys.MASK])
r_max_sq: Final = r_max**2
pairs = jnp.argwhere(
(dists_mtx < r_max_sq) & (dists_mtx > 0.0),
size=padded_graph.senders.shape[0],
fill_value=-1, # Use -1 to identify the padding values
)
n_edge: Final = padded_graph.edges[keys.MASK].shape[0]
new_edge_mask = pairs[:, 0] >= 0 # i.e. not -1
new_n_edge = sum(new_edge_mask)
new_senders = jnp.where(new_edge_mask, pairs[:, 0], 0.0)
new_receivers = jnp.where(new_edge_mask, pairs[:, 1], 0.0)
return padded_graph._replace(
senders=new_senders,
receivers=new_receivers,
n_edge=jnp.array([new_n_edge, n_edge - new_n_edge]),
edges={"mask": new_edge_mask},
nodes=nodes_dict,
)