"""Source code for the ``tensorial.signals.bases`` module: spherical-harmonic and radial-spherical basis sets."""

import abc

import e3nn_jax as e3j
import jax.numpy as jnp

import tensorial

from . import radials


class SphericalBasis(tensorial.Attr):
    """A set of spherical harmonics basis functions.

    The irreps of the basis are those returned by
    ``e3j.s2_irreps(l_max, p_val, p_arg)``.

    Args:
        l_max: Maximum spherical-harmonic degree of the basis.
        p_val: Parity of the function values, forwarded to ``e3j.s2_irreps``.
        p_arg: Parity of the function argument, forwarded to ``e3j.s2_irreps``.
    """

    def __init__(self, l_max: int, p_val=1, p_arg=-1):
        self._l_max = l_max
        self._p_val = p_val
        self._p_arg = p_arg
        super().__init__(e3j.s2_irreps(l_max, p_val, p_arg))

    @property
    def l_max(self) -> int:
        """Maximum spherical-harmonic degree of the basis."""
        return self._l_max

    @property
    def p_val(self) -> int:
        """Parity of the function values passed to ``e3j.s2_irreps``."""
        return self._p_val

    @property
    def p_arg(self) -> int:
        """Parity of the function argument passed to ``e3j.s2_irreps``."""
        return self._p_arg

    def evaluate(self, x) -> e3j.IrrepsArray:
        """Evaluate the spherical harmonics at the passed values.

        Args:
            x: Cartesian vectors to evaluate at. Presumably these have a trailing
                axis of size 3 (TODO confirm against callers).

        Returns:
            The spherical harmonics for ``self.irreps`` evaluated at ``x``.

        Note:
            ``normalize=True`` is passed to ``e3j.spherical_harmonics``, so the
            inputs are normalised onto the unit sphere before evaluation. Only
            the direction of ``x`` matters, and the inputs do NOT need to be
            normalised beforehand. The harmonics use ``"integral"``
            normalization.
        """
        return e3j.spherical_harmonics(
            self.irreps, x, normalize=True, normalization="integral"
        )

    def create_tensor(self, value) -> e3j.IrrepsArray:
        """Create the tensor for ``value``: the spherical harmonics evaluated at it."""
        return self.evaluate(value)
class RadialSphericalBasis(tensorial.Attr):
    """Base class for a basis combining radial functions with spherical harmonics.

    Subclasses provide :meth:`evaluate`. :meth:`create_tensor` simply
    delegates to it.
    """

    def create_tensor(self, value: jnp.array) -> jnp.array:
        """Expand ``value`` in this basis by evaluating the basis functions at it."""
        # NOTE(review): ``jnp.array`` is a function rather than a type; the
        # annotations are kept unchanged for interface stability.
        signal = self.evaluate(value)
        return signal

    @abc.abstractmethod
    def evaluate(self, value):
        """Return the values of the basis functions at ``value``.

        Intended to be overridden by subclasses.
        """
        # NOTE(review): ``@abc.abstractmethod`` is only enforced when the class's
        # metaclass is ``abc.ABCMeta``. This class does not inherit ``abc.ABC``
        # itself, so enforcement depends on ``tensorial.Attr`` — confirm.
class SimpleRadialSphericalBasis(RadialSphericalBasis):
    """Product basis pairing every radial function with every spherical harmonic.

    Args:
        radial: The radial basis. ``len(radial)`` gives the number of radial
            functions.
        spherical: The spherical-harmonics basis.
    """

    def __init__(self, radial: radials.RadialBasis, spherical: SphericalBasis):
        self.radial = radial
        self.spherical = spherical
        # The combined irreps are the spherical irreps repeated once per radial
        # function.
        radial_count = len(self.radial)
        combined_irreps = spherical.irreps.repeat(radial_count)
        super().__init__(combined_irreps)

    def evaluate(self, value):
        """Evaluate every radial x spherical basis function at ``value``.

        Returns:
            The outer product over the last axes of the radial values and the
            spherical-harmonic values. The radial values are taken at
            ``|value|`` (norm over the last axis). The output's last two axes
            index the radial function and the spherical-harmonic component.
        """
        sph_values = self.spherical.evaluate(value).array
        # The radial functions depend only on the vector length.
        lengths = jnp.linalg.norm(value, axis=-1)
        radial_values = self.radial.evaluate(lengths)
        return jnp.einsum("...i,...j->...ij", radial_values, sph_values)

    def expand(self, x: jnp.array, coefficients: jnp.array):
        """Evaluate ``sum_ij coefficients[i, j] * basis_ij(x)``.

        Args:
            x: Point(s) at which to evaluate the expansion.
            coefficients: 2D coefficient array matching the last two axes of
                :meth:`evaluate`'s output.

        Returns:
            The expanded signal, with one value per leading index of the basis
            values.
        """
        return jnp.einsum("ij,...ij->...", coefficients, self.evaluate(x))