Source code for tensorial.geometry.jax_neighbours

import functools
import numbers

import beartype
import equinox
import jax
import jax.numpy as jnp
import jaxtyping as jt

from tensorial.typing import CellType, PbcType

from . import distances, unit_cells

i32 = jnp.int32  # pylint: disable=invalid-name

DEFAULT_MAX_CELL_MULTIPLES = 10_000
MASK_VALUE = -1


[docs] class NeighbourList(equinox.Module, distances.NeighbourList): neighbours: jax.Array cell_indices: jax.Array actual_max_neighbours: int _finder: "NeighbourFinder" def __init__( self, neighbours: jt.ArrayLike, cell_indices: jt.ArrayLike, actual_max_neighbours: jax.Array = -1, finder: "NeighbourFinder" = None, ): if neighbours.shape != cell_indices.shape[:2]: raise ValueError("Cell indices and neighbours must have same shape") # checkify.check(neighbours.shape == cell_indices.shape[:2], "Cell indices and neighbours # must have same shape") # if jnp.any(neighbours > neighbours.shape[0]): # raise ValueError( # "One or more entries in the neighbours array refers to an index higher than the # maximum possible") self.neighbours = jnp.asarray(neighbours) self.cell_indices = jnp.asarray(cell_indices) self.actual_max_neighbours = actual_max_neighbours self._finder = finder @property def num_particles(self) -> int: return self.neighbours.shape[0] @property def max_neighbours(self) -> int: return self.neighbours.shape[1] @property def did_overflow(self) -> bool: """Returns `True` if the list could not accommodate all the neighbours. The actual number needed is stored in `actual_max_neighbours` """ return self.actual_max_neighbours > self.max_neighbours
[docs] def get_edges(self) -> distances.Edges: mask = self.neighbours != MASK_VALUE from_idx = jnp.repeat( jnp.arange(0, self.num_particles)[:, None], self.max_neighbours, axis=1 ) return distances.Edges(from_idx[mask], self.neighbours[mask], self.cell_indices[mask])
[docs] def list_overflow(self) -> bool: return self.actual_max_neighbours > self.max_neighbours
[docs] def reallocate(self, positions: jt.ArrayLike) -> "NeighbourList": return self._finder.get_neighbours(positions, max_neighbours=self.actual_max_neighbours)
[docs] class NeighbourFinder(equinox.Module, distances.NeighbourFinder):
[docs] def get_neighbours(self, positions: jt.ArrayLike, max_neighbours: int = None) -> NeighbourList: """Get the neighbour list for the given positions"""
[docs] def estimate_neighbours(self, positions: jt.ArrayLike) -> int: """Estimate the number of neighbours per particle"""
[docs] class OpenBoundary(NeighbourFinder): _cutoff: float _include_self: bool def __init__(self, cutoff: numbers.Number, include_self=False): self._cutoff = float(cutoff) self._include_self = include_self
[docs] def get_neighbours(self, positions: jt.ArrayLike, max_neighbours: int = None) -> NeighbourList: positions = jnp.asarray(positions) num_points = positions.shape[0] max_neighbours = max_neighbours or self.estimate_neighbours(positions) # Get the neighbours mask neigh_mask = jax.vmap(neighbours_mask_direct, (0, None, None))( positions, positions, self._cutoff ) if not self._include_self: neigh_mask &= ~jnp.eye(num_points, dtype=bool) get_neighbours = functools.partial(jnp.argwhere, size=max_neighbours, fill_value=-1) to_idx = jax.vmap(get_neighbours)(neigh_mask)[..., 0] cell_indices = jnp.zeros((*to_idx.shape, 3), dtype=int) return NeighbourList( to_idx, cell_indices, actual_max_neighbours=jnp.max(neigh_mask.sum(axis=1)), finder=self, )
[docs] def estimate_neighbours(self, positions: jt.ArrayLike) -> int: positions = jnp.asarray(positions) dimensions = jnp.max(positions, axis=0) - jnp.min(positions, axis=0) # Clamp the minimum otherwise we might get a div by zero dimensions = jnp.where(dimensions == 0.0, 1.0, dimensions) approx_density = positions.shape[0] / jnp.prod(dimensions) return int(3 * jnp.ceil(approx_density * unit_cells.sphere_volume(self._cutoff)).item())
[docs] class PeriodicBoundary(NeighbourFinder): _cell: jax.Array _cutoff: float _cell_list: jax.Array _grid_points: jax.Array _include_self: bool _include_images: bool _self_cell: int def __init__( self, cell: CellType, cutoff: numbers.Number, pbc: PbcType | None = None, *, max_cell_multiples: int = DEFAULT_MAX_CELL_MULTIPLES, include_self=False, include_images=True, ): self._cell = jnp.asarray(cell) self._cutoff = float(cutoff) self._cell_list, self._grid_points = get_cell_list( self._cell, cutoff, pbc, max_cell_multiples=max_cell_multiples ) self._self_cell = jnp.argwhere( jax.vmap(jnp.array_equal, (0, None))(self._cell_list, jnp.zeros(3, dtype=i32)) )[0, 0].item() self._include_self = include_self self._include_images = include_images
[docs] def get_neighbours(self, positions: jt.ArrayLike, max_neighbours: int = None) -> NeighbourList: num_points = positions.shape[0] num_cells = self._cell_list.shape[0] max_neighbours = ( max_neighbours if max_neighbours is not None else self.estimate_neighbours(positions) ) neighbours = jax.vmap(lambda shift: shift + positions)(self._grid_points).reshape(-1, 3) # Get the neighbours mask neigh_mask = jax.vmap(neighbours_mask_direct, (0, None, None))( positions, neighbours, self._cutoff ) if not self._include_self or not self._include_images: neigh_mask2 = neigh_mask.reshape(num_points, num_cells, num_points) mask = ~jnp.eye(num_points, dtype=bool) if not self._include_images: neigh_mask2 = neigh_mask2 & mask if not self._include_self: neigh_mask2 = neigh_mask2.at[:, self._self_cell, :].set( neigh_mask2[:, self._self_cell, :] & mask ) neigh_mask = neigh_mask2.reshape(num_points, num_cells * num_points) get_neighbours = functools.partial(jnp.argwhere, size=max_neighbours, fill_value=MASK_VALUE) to_idx = jax.vmap(get_neighbours)(neigh_mask)[..., 0] # Repeat the cells for each cells = jnp.repeat(self._cell_list, num_points, axis=0) cell_indices = jax.vmap(jnp.take, (None, 0, None))(cells, to_idx, 0) return NeighbourList( jnp.where(to_idx == MASK_VALUE, MASK_VALUE, to_idx % num_points), cell_indices, actual_max_neighbours=jnp.max(neigh_mask.sum(axis=1)), finder=self, )
[docs] def estimate_neighbours(self, positions: jt.ArrayLike) -> int: density = positions.shape[0] / unit_cells.cell_volume(self._cell) return int(1.3 * jnp.ceil(density * unit_cells.sphere_volume(self._cutoff) + 1.0).item())
[docs] @jt.jaxtyped(typechecker=beartype.beartype) def neighbour_finder( cutoff: numbers.Number, cell: CellType | None = None, pbc: PbcType | None = None, include_self: bool = False, **kwargs, ) -> NeighbourFinder: if pbc is not None and any(pbc): return PeriodicBoundary(cell, cutoff, pbc, include_self=include_self, **kwargs) return OpenBoundary(cutoff, include_self=include_self)
[docs] def generate_positions(cell: jax.Array, positions: jax.Array, cell_shifts: jax.Array) -> jax.Array: return jax.vmap(lambda shift: (shift @ cell) + positions)(cell_shifts)
[docs] def get_cell_list( cell: CellType, cutoff: numbers.Number, pbc: PbcType | None = (True, True, True), max_cell_multiples: int = DEFAULT_MAX_CELL_MULTIPLES, ) -> tuple[jax.Array, jax.Array]: cell = jnp.asarray(cell) # Get the multipliers for each cell direction cell_ranges = unit_cells.get_cell_multiple_ranges(cell, cutoff=cutoff, pbc=pbc) # Clamp the cell range cell_ranges = tuple( (max(nmin, -max_cell_multiples), min(nmax, max_cell_multiples)) for nmin, nmax in cell_ranges ) cell_grid = jnp.array( jnp.meshgrid( jnp.arange(*cell_ranges[0]), jnp.arange(*cell_ranges[1]), jnp.arange(*cell_ranges[2]), indexing="ij", ) ) reshaped = cell_grid.T.reshape(-1, 3) grid_points = reshaped @ cell # corners = jnp.array(list(itertools.product((0, 1), repeat=3)), dtype=i32) # corners = corners @ cell # mask = jax.vmap(neighbours_mask_direct, (0, None, None))( # corners, grid_points, cutoff).any(axis=0) # return reshaped[mask], grid_points[mask] return reshaped, grid_points
[docs] def neighbours_mask_aabb( centre: jt.ArrayLike, neighbours: jt.ArrayLike, cutoff: float ) -> jax.Array: """Get the indices of all points that are within a cutoff sphere centred on `centre` with a radius `cutoff` using the Axis Aligned Bounding Box method """ diag = cutoff / jnp.sqrt(3.0) centred = neighbours - centre # First find those that fit into the axis aligned bounding box that fits within the cutoff # sphere definitely_neighbour = jnp.array( jnp.all((-diag < centred) & (centred < diag), axis=1), dtype=bool ) maybe_neighbour = jnp.all(-cutoff < centred & centred < cutoff, axis=1) & ~definitely_neighbour # Now check the remaining ones that lie within the shell between the AABB that fits within the # sphere and the AABB that bounds the sphere maybe_norm_sq = jnp.sum(centred[maybe_neighbour] ** 2, axis=1) return definitely_neighbour.at[maybe_neighbour].set( definitely_neighbour[maybe_neighbour] | (maybe_norm_sq < (cutoff * cutoff)) )
[docs] def neighbours_mask_direct( centre: jt.ArrayLike, neighbours: jt.ArrayLike, cutoff: float ) -> jax.Array: """Get the indices of all points that are within a cutoff sphere centred on ``centre`` with a radius ``cutoff`` by calculating all distance vector norms and masking those within the cutoff """ centred = neighbours - centre return jnp.array(jnp.sum(centred**2, axis=1) <= (cutoff * cutoff), dtype=bool)