Source code for tensorial.signals.functions

import abc

import jax
import jax.numpy as jnp


[docs] class Function: """Base class for functions""" def __add__(self, other) -> "Sum": return Sum((self, other)) def __call__(self, x): return self.evaluate(x)
[docs] @abc.abstractmethod def evaluate(self, x): """Evaluate the function at point `x`"""
[docs] class DiracDelta(Function): """A Dirac delta with an optional weight""" def __init__(self, pos, weight=1.0): self.pos = pos self.weight = weight
[docs] def evaluate(self, x): return jax.lax.cond(not (self.pos - x).any(), lambda: self.weight, lambda: 0.0)
[docs] class IsotropicGaussian(Function): """A 3D Gaussian with an optional weight and scalar sigma""" def __init__(self, pos, sigma, weight=1.0) -> None: super().__init__() self.pos = pos self.sigma = sigma self.weight = weight
[docs] def evaluate(self, x): return ( self.weight / (jnp.sqrt(2 * jnp.pi)) * jnp.exp(-(jnp.sum((x - self.pos) ** 2)) / (2 * self.sigma**2)) )
[docs] class Sum(Function): """A sum of other functions""" def __init__(self, functions: tuple) -> None: super().__init__() transformed = [] for func in functions: if isinstance(func, Sum): transformed.extend(func.functions) else: transformed.append(func) self.functions = tuple(transformed)
[docs] def evaluate(self, x): return jnp.sum(func(x) for func in self.functions)