Source code for tensorial.gcnn.atomic._modules

from collections.abc import Sequence

import beartype
import equinox
import jax
import jax.numpy as jnp
import jaxtyping as jt
from jaxtyping import Int
import jraph

from tensorial.typing import Array

from . import keys
from .. import _modules as gcnn_modules
from .. import keys as gcnn_keys
from ... import nn_utils

__all__ = "SpeciesTransform", "per_species_rescale"


[docs] @jt.jaxtyped(typechecker=beartype.beartype) class SpeciesTransform(equinox.Module): """Take an ordered list of species and transform them into an integer corresponding to their position in the list """ atomic_numbers: Int[jax.Array, "numbers"] field: str = keys.ATOMIC_NUMBERS out_field: str = gcnn_keys.SPECIES @jt.jaxtyped(typechecker=beartype.beartype) def __init__( self, atomic_numbers: Sequence[int] | Int[Array, "numbers"], field: str = keys.ATOMIC_NUMBERS, out_field: str = gcnn_keys.SPECIES, ): self.atomic_numbers = jnp.asarray(atomic_numbers) self.field = field self.out_field = out_field @jt.jaxtyped(typechecker=beartype.beartype) def __call__( self, graph: jraph.GraphsTuple ) -> jraph.GraphsTuple: # pylint: disable=arguments-differ nodes = graph.nodes nodes[self.out_field] = nn_utils.vwhere(nodes[self.field], self.atomic_numbers) return graph._replace(nodes=nodes)
[docs] def per_species_rescale( num_types: int, field: str, *, types_field: str = None, out_field: str = None, shifts: jax.typing.ArrayLike = None, scales: jax.typing.ArrayLike = None, ) -> gcnn_modules.IndexedRescale: types_field = types_field or ("nodes", gcnn_keys.SPECIES) return gcnn_modules.IndexedRescale( num_types, index_field=types_field, field=field, out_field=out_field, shifts=shifts, scales=scales, )