Source code for tensorial.signals.expansion

"""Module for functions performing expansion of functions with a basis"""

import functools

import jax.numpy as jnp

from . import bases, functions


[docs] @functools.singledispatch def expand( # pylint: disable=unused-argument basis: bases.RadialSphericalBasis, function: functions.Function ) -> jnp.array: """Expand a function in the given basis"""
[docs] @expand.register def expand_(basis: bases.SimpleRadialSphericalBasis, function: functions.Function) -> jnp.array: if isinstance(function, functions.DiracDelta): return function.weight * basis.evaluate(function.pos) if isinstance(function, functions.Sum): return jnp.sum(expand(basis, function) for function in function.functions) raise TypeError(f"Unsupported function {function.__class__.__name__}")