Source code for tensorial.utils
import types
import e3nn_jax as e3j
import jax
import jax.numpy as jnp
import numpy as np
from tensorial.typing import IntoIrreps
[docs]
def infer_backend(pytree) -> types.ModuleType:
"""Try to infer a backend from the passed pytree"""
any_numpy = any(isinstance(x, np.ndarray) for x in jax.tree_util.tree_leaves(pytree))
any_jax = any(isinstance(x, jax.Array) for x in jax.tree_util.tree_leaves(pytree))
if any_numpy and any_jax:
raise ValueError("Cannot mix numpy and jax arrays")
if any_numpy:
return np
if any_jax:
return jnp
return jnp
[docs]
def zeros(
irreps: IntoIrreps, leading_shape: tuple = (), dtype: jnp.dtype = None, np_=jnp
) -> e3j.IrrepsArray:
r"""Create an IrrepsArray of zeros."""
irreps = e3j.Irreps(irreps)
array = np_.zeros(leading_shape + (irreps.dim,), dtype=dtype)
return e3j.IrrepsArray(irreps, array, zero_flags=(True,) * len(irreps))
[docs]
def zeros_like(irreps_array: e3j.IrrepsArray) -> e3j.IrrepsArray:
r"""Create an IrrepsArray of zeros with the same shape as another IrrepsArray."""
np_ = infer_backend(irreps_array.array)
return zeros(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype, np_=np_)
[docs]
def ones(
irreps: IntoIrreps, leading_shape: tuple = (), dtype: jnp.dtype = None, np_=jnp
) -> e3j.IrrepsArray:
r"""Create an IrrepsArray of ones."""
irreps = e3j.Irreps(irreps)
array = np_.ones(leading_shape + (irreps.dim,), dtype=dtype)
return e3j.IrrepsArray(irreps, array, zero_flags=(False,) * len(irreps))
[docs]
def ones_like(irreps_array: e3j.IrrepsArray) -> e3j.IrrepsArray:
r"""Create an IrrepsArray of ones with the same shape as another IrrepsArray."""
np_ = infer_backend(irreps_array.array)
return ones(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype, np_=np_)