from typing import Literal
import beartype
import e3nn_jax as e3j
import jax
import jax.numpy as jnp
import jaxtyping as jt
from jaxtyping import Array, Float, Int
from tensorial.typing import IrrepsArrayShape
from . import base
__all__ = "SphericalHarmonic", "CartesianTensor", "NoOp", "OneHot", "AsIrreps"
[docs]
class NoOp(base.Attr):
"""An attribute that keeps IrrepsArrays with specified irreps unchanged"""
def _validate(self, value):
assert isinstance(value, e3j.IrrepsArray), "Expected an IrrepsArray"
assert value.irreps == self.irreps, "Irreps mismatch"
[docs]
def create_tensor(self, value: e3j.IrrepsArray) -> e3j.IrrepsArray:
self._validate(value)
return value
[docs]
def from_tensor(self, tensor: e3j.IrrepsArray) -> e3j.IrrepsArray:
self._validate(tensor)
return tensor
[docs]
class AsIrreps(base.Attr):
def _validate(self, value):
assert isinstance(value, jnp.ndarray), "Expected a jnp.ndarray"
assert value.shape[-1] == self.irreps.dim, "Dimension mismatch"
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
def create_tensor(self, value: jt.ArrayLike) -> e3j.IrrepsArray:
self._validate(value)
return e3j.IrrepsArray(self.irreps, value)
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
def from_tensor(self, tensor: e3j.IrrepsArray) -> e3j.IrrepsArray:
assert tensor.irreps == self.irreps, "Irreps mismatch"
return tensor
[docs]
class SphericalHarmonic(base.Attr):
"""An attribute that is the spherical harmonics evaluated as some values"""
normalise: bool
normalisation: Literal["integral", "component", "norm"] | None = None
algorithm: tuple[str] | None = None
def __init__(
self,
irreps,
normalise,
normalisation: Literal["integral", "component", "norm"] | None = None,
*,
algorithm: tuple[str] = None,
):
super().__init__(irreps)
self.normalise = normalise
self.normalisation = normalisation
self.algorithm = algorithm
[docs]
def create_tensor(self, value: jax.Array | e3j.IrrepsArray) -> jnp.array:
return e3j.spherical_harmonics(
self.irreps,
base.as_array(value),
normalize=self.normalise,
normalization=self.normalisation,
algorithm=self.algorithm,
)
[docs]
class OneHot(base.Attr):
"""One-hot encoding as a direct sum of even scalars"""
def __init__(self, num_classes: int):
super().__init__(num_classes * e3j.Irrep(0, 1))
@property
def num_classes(self) -> int:
mul_irrep = self.irreps[0]
if isinstance(mul_irrep, e3j.MulIrrep):
return mul_irrep.mul
raise ValueError("Expected self.irreps to contain a MulIrrep.")
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
def create_tensor(self, value: Int[Array, "n_vals"]) -> IrrepsArrayShape["n_node num_classes"]:
return e3j.IrrepsArray(self.irreps, jax.nn.one_hot(value, self.num_classes))
[docs]
class CartesianTensor(base.Attr):
formula: str
keep_ir: e3j.Irreps | list[e3j.Irrep] | None
irreps_dict: dict
change_of_basis: jax.Array
_indices: str
def __init__(self, formula: str, keep_ir=None, **irreps_dict) -> None:
self.formula = formula
self.keep_ir = keep_ir
self.irreps_dict = irreps_dict
self._indices = formula.split("=")[0].replace("-", "")
# Construct the change of basis arrays
cob = e3j.reduced_tensor_product_basis(formula, keep_ir=self.keep_ir, **self.irreps_dict)
self.change_of_basis = cob.array
super().__init__(cob.irreps)
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
def create_tensor(self, value: jt.ArrayLike) -> e3j.IrrepsArray:
return super().create_tensor( # pylint: disable=not-callable
jnp.einsum("ij,ijz->z", value, self.change_of_basis)
)
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
def from_tensor(
self,
tensor: IrrepsArrayShape["irreps"] | IrrepsArrayShape["batch irreps"],
) -> Float[jax.Array, "..."] | Float[jax.Array, "batch ..."]:
"""Take an irrep tensor and perform the change of basis transformation back to a Cartesian
tensor
Args:
tensor: the irrep tensor
Returns:
the Cartesian tensor
"""
rot = self.change_of_basis.reshape(-1, self.change_of_basis.shape[-1])
cartesian = base.as_array(tensor) @ rot.T
return cartesian.reshape((*tensor.shape[:-1], *self.change_of_basis.shape[:-1]))