Source code for tensorial.signals.radials

import abc
from collections.abc import Callable
import math

import e3nn_jax
import jax
import jax.numpy as jnp
from typing_extensions import override

import tensorial


[docs] class RadialBasis(tensorial.Attr): """A set of radial basis functions""" _number: int _domain: tuple[float, float] def __init__(self, number: int, domain=(0.0, jnp.inf)): self._number = number self._domain = domain super().__init__(number * e3nn_jax.Irrep(0, p=1)) @property def number(self) -> int: """Get the number of radial functions in the basis""" return self._number def __len__(self): """The total number of radial functions""" return self._number # @abc.abstractmethod # def __getitem__(self, n: int) -> Callable: # """Get the radial function with index `n`""" @property def domain(self) -> tuple[float, float]: return self._domain
[docs] @abc.abstractmethod def evaluate(self, radius): """Evaluate the radial basis at `r`"""
[docs] @override def create_tensor(self, value) -> jnp.array: return super().create_tensor(self.evaluate(value)) # pylint: disable=not-callable
[docs] class E3nnRadial(RadialBasis): """Select a radial function from the one-hot linspace built into e3nn-jax see: https://e3nn-jax.readthedocs.io/en/latest/api/radial.html """ _basis: Callable[[float], jnp.array] _cutoff: float def __init__(self, basis: str, max_radius: float, number: int, *, cutoff=None, min_radius=0.0): super().__init__(number, domain=(min_radius, max_radius)) self._basis = basis self._cutoff = cutoff @property def basis(self) -> str: return self._basis @property def cutoff(self) -> bool | None: return self._cutoff
[docs] @override def evaluate(self, radius): return e3nn_jax.soft_one_hot_linspace( radius, start=self.domain[0], end=self.domain[1], number=self.number, basis=self.basis, cutoff=self.cutoff, )
[docs] class E3nnPolyEnvelope(RadialBasis): """Polynomial envelope that can be used to make a radial basis smoothly approach zero at the cutoff""" _radials: RadialBasis _smoothing_start: float _smoothing_width: float _envelope: Callable[[float], jax.Array] def __init__(self, basis: RadialBasis, smoothing_start: float, n0: int, n1: int): super().__init__(basis.number, basis.domain) if smoothing_start < basis.domain[0] or smoothing_start > basis.domain[1]: raise ValueError( f"The start of the smoothing envelope ({smoothing_start}) must be within the " f"domain ({basis.domain}) of the radial basis" ) self._radials = basis self._smoothing_start = smoothing_start self._smoothing_width = basis.domain[1] - smoothing_start self._envelope = e3nn_jax.poly_envelope(n0, n1, self._smoothing_width)
[docs] def evaluate(self, radius): values = self._radials.evaluate(radius) mask = radius >= self._smoothing_start # Calculate the envelope for r values in the masked range envelope = self._envelope(radius[mask] - self._smoothing_start) # Multiply the values by the envelope, expanding the envelope to repeat by the number of # radials values = values.at[mask].set( values[mask, :] * envelope[:, jnp.newaxis].repeat(self.number, axis=1) ) return values
[docs] class OrthoBasis(RadialBasis): radial_samples: jax.Array radial_step: jax.Array area_samples: jax.Array f_samples: jax.Array def __init__(self, radials: RadialBasis, n_samples: int): super().__init__(radials.number, radials.domain) self.radial_samples = jnp.linspace(radials.domain[0], radials.domain[1], n_samples) self.radial_step = self.radial_samples[1] - self.radial_samples[0] non_orthogonal_samples = radials.evaluate(self.radial_samples) self.area_samples = 4 * math.pi * self.radial_samples * self.radial_samples self.f_samples = jnp.zeros_like(non_orthogonal_samples) u0 = non_orthogonal_samples[:, 0] self.f_samples = self.f_samples.at[:, 0].set(u0 / self.norm(u0)) # self.f_samples[:, 0] = u0 / self.norm(u0) for i in range(1, self.number): ui = non_orthogonal_samples[:, i] for j in range(i): uj = self.f_samples[:, j] ui -= self.inner_product(uj, ui) / self.inner_product(uj, uj) * uj self.f_samples = self.f_samples.at[:, i].set(ui / self.norm(ui))
[docs] @override def evaluate(self, radius): r_normalized = radius / self.radial_step r_normalized_floor_int = jnp.floor(r_normalized).astype(jnp.int64) # Get the indices of the samples just below the values of r indices_low = jnp.minimum(r_normalized_floor_int, jnp.array([len(self.radial_samples) - 2])) # Get what fraction through the samples we should be at r_remainder_normalized = r_normalized - indices_low r_remainder_normalized = r_remainder_normalized[:, jnp.newaxis].repeat(self.number, axis=1) low_samples = self.f_samples[indices_low, :] high_samples = self.f_samples[indices_low + 1, :] ret = low_samples * (1 - r_remainder_normalized) + high_samples * r_remainder_normalized return ret
[docs] def inner_product(self, val_a, val_b): return jnp.trapezoid(val_a * val_b * self.area_samples, self.radial_samples)
[docs] def norm(self, val): return jnp.sqrt(self.inner_product(val, val))