Source code for tensorial.base

from collections.abc import Mapping
import functools
from typing import Any

import beartype
import e3nn_jax as e3j
import equinox
from flax import linen
import jax
import jax.numpy as jnp
import jaxtyping as jt
import numpy as np
from reax.utils import arrays

from tensorial.typing import IntoIrreps

__all__ = (
    "IrrepsObj",
    "IrrepsTree",
    "Attr",
    "create",
    "create_tensor",
    "irreps",
    "get",
    "Tensorial",
    "tensorial_attrs",
    "from_tensor",
    "as_array",
)


Array = jax.typing.ArrayLike


def atleast_1d(arr, np_=jnp) -> jax.Array | np.ndarray:
    np_ = np_ if np_ is not None else arrays.infer_backend(arr)
    arr = np_.asarray(arr)
    return arr if np_.ndim(arr) >= 1 else np_.reshape(arr, -1)


[docs] def as_array(arr: jt.ArrayLike | e3j.IrrepsArray) -> jax.Array: """Get a standard JAX array given either: 1. a numpy.ndarray 2. an e3nn_jax.IrrepsArray, or 3. a jax.Array (in which case it is returned unmodified) Args: arr: the array to get the value for Returns: the JAX array """ if isinstance(arr, e3j.IrrepsArray): return arr.array return jnp.asarray(arr)
[docs] class Attr(equinox.Module): """Irreps object attribute""" irreps: e3j.Irreps def __init__(self, irreps: IntoIrreps) -> None: # pylint: disable=redefined-outer-name self.irreps = e3j.Irreps(irreps)
[docs] @jt.jaxtyped(typechecker=beartype.beartype) def create_tensor(self, value: Any) -> e3j.IrrepsArray: return e3j.IrrepsArray(self.irreps, atleast_1d(value))
[docs] @jt.jaxtyped(typechecker=beartype.beartype) def from_tensor(self, tensor: e3j.IrrepsArray) -> Any: """This can be overwritten to perform the backward transform of `create_tensor`""" return tensor
[docs] class IrrepsObj: """An object that contains tensorial attributes."""
Tensorial = Attr | IrrepsObj | type(IrrepsObj) | dict | linen.FrozenDict | e3j.Irreps IrrepsTree = IrrepsObj | dict[str, Tensorial] ValueType = Any | list["ValueType"] | dict[str, "ValueType"]
[docs] @functools.singledispatch def create(tensorial: Tensorial, value: Mapping): if not issubclass(tensorial, IrrepsObj): raise TypeError(tensorial.__class__.__name__) value_dict = {} for name, val in tensorial_attrs(tensorial).items(): value_dict[name] = create(val, value[name]) return value_dict
@create.register def _(attr: Attr, value) -> e3j.IrrepsArray: """Leaf, so create the tensor""" return create_tensor(attr, value) @create.register def _(attr: IrrepsObj, value) -> e3j.IrrepsArray: """Leaf, so create the tensor""" return create_tensor(attr, value) @create.register def _(attr: e3j.Irreps, value) -> e3j.IrrepsArray: """Leaf, so create the tensor""" return create_tensor(attr, value)
[docs] @functools.singledispatch def irreps(tensorial: Tensorial) -> e3j.Irreps: """Get the irreps for a tensorial type""" if not issubclass(tensorial, IrrepsObj): raise TypeError(tensorial.__class__.__name__) # IrrepsObj code: total_irreps = None for name, val in tensorial_attrs(tensorial).items(): try: total_irreps = val.irreps if total_irreps is None else total_irreps + val.irreps except AttributeError as exc: raise AttributeError(f"Failed to get irreps for {name}") from exc return total_irreps
@irreps.register def _irreps_attr(attr: Attr) -> e3j.Irreps: return attr.irreps @irreps.register def _irreps_irreps(tensorial: e3j.Irreps) -> e3j.Irreps: return tensorial
[docs] @functools.singledispatch def create_tensor(tensorial: Tensorial, value: ValueType) -> e3j.IrrepsArray: """Create a tensor for a tensorial type""" try: # issubclass can fail if the value is not a class, so we guard against that here # and raise later with a more meaningful message is_subclass = issubclass(tensorial, IrrepsObj) except TypeError: pass # Will raise at bottom of function else: if is_subclass: return create_tensor(tensorial_attrs(tensorial), value) raise TypeError(f"Unrecognised tensorial type: {tensorial.__class__.__name__}")
@create_tensor.register def _create_tensor_irreps_obj(tensorial: IrrepsObj, value) -> e3j.IrrepsArray: return create_tensor(tensorial_attrs(tensorial), value) @create_tensor.register def _create_tensor_dict(tensorial: dict, value) -> e3j.IrrepsArray: return e3j.concatenate( [create_tensor(attr, value[key]) for key, attr in tensorial.items()], ) @create_tensor.register def _create_tensor_frozen_dict(tensorial: linen.FrozenDict, value): return create_tensor(tensorial.unfreeze(), value) @create_tensor.register def _create_tensor_irreps( # pylint: disable=redefined-outer-name irreps: e3j.Irreps, value: Array ) -> e3j.IrrepsArray: return e3j.IrrepsArray(irreps, value) @create_tensor.register def _create_tensor_str( # pylint: disable=redefined-outer-name irreps: str, value: Array ) -> e3j.IrrepsArray: return e3j.IrrepsArray(irreps, value) @create_tensor.register def _create_tensor_attr(attr: Attr, value) -> e3j.IrrepsArray: return attr.create_tensor(value)
[docs] @functools.singledispatch def from_tensor(tensorial: Tensorial, value) -> ValueType: """Create a tensor for a tensorial type""" try: # issubclass can fail if the value is a class, so we guard against that here # and raise later with a more meaningful message is_subclass = issubclass(tensorial, IrrepsObj) except TypeError: pass # Will raise at bottom of function else: if is_subclass: return from_tensor(tensorial_attrs(tensorial), value) raise TypeError(f"Unrecognised tensorial type: {tensorial.__class__.__name__}")
@from_tensor.register def _from_tensor_irreps_obj(tensorial: IrrepsObj, value) -> dict[str, ValueType]: return from_tensor(tensorial_attrs(tensorial), value) @from_tensor.register def _from_tensor_dict(tensorial: dict, value: Array) -> dict[str, ValueType]: dims = jnp.array(tuple(map(lambda val: irreps(val).dim, tensorial.values()))) split_points = jnp.array(tuple(jnp.sum(dims[:i]) for i in range(len(dims) - 1))) split_value = jnp.split(value, split_points) return { key: from_tensor(dict_value, array_value) for array_value, (key, dict_value) in zip(split_value, tensorial_attrs(tensorial).items()) } @from_tensor.register def _from_tensor_frozen_dict(tensorial: linen.FrozenDict, value): return from_tensor(tensorial.unfreeze(), value) @from_tensor.register def _from_tensor_irreps( # pylint: disable=redefined-outer-name irreps: e3j.Irreps, value: e3j.IrrepsArray ) -> e3j.IrrepsArray: # Nothing to do if not irreps == value.irreps: raise ValueError(f"Irreps mismatch: {irreps} != {value.irreps}") return value @from_tensor.register def _from_tensor(attr: Attr, value) -> e3j.IrrepsArray: return attr.from_tensor(value)
[docs] @functools.singledispatch def tensorial_attrs(irreps_obj) -> dict[str, Tensorial]: if issubclass(irreps_obj, IrrepsObj): return { name: val for name, val in vars(irreps_obj).items() if not (name.startswith("_") or callable(val)) } raise TypeError(irreps_obj.__class__.__name__)
@tensorial_attrs.register def _tensorial_attrs_irreps_obj(irreps_obj: IrrepsObj) -> dict[str, Tensorial]: """Get the irrep attributes for the passed object""" attrs = tensorial_attrs(type(irreps_obj)) attrs.update( { name: val for name, val in vars(irreps_obj).items() if not (name.startswith("_") or callable(val)) } ) return attrs @tensorial_attrs.register def _tensorial_attrs_dict(irreps_obj: dict) -> dict[str, Tensorial]: return {name: val for name, val in irreps_obj.items() if not name.startswith("_")} @tensorial_attrs.register def _tensorial_attrs_frozen_dict(irreps_obj: linen.FrozenDict) -> dict[str, Tensorial]: return tensorial_attrs(irreps_obj.unfreeze())
[docs] def get(irreps_obj: type[IrrepsObj], tensor: Array, attr_name: str = None) -> Array: if not attr_name: return tensor attrs = tensorial_attrs(irreps_obj) idx = list(attrs.keys()).index(attr_name) # Get the linear start and end index of the tensor corresponding to the passed attribute begin = sum(irreps(attr).dim for attr in list(attrs.values())[:idx]) end = begin + irreps(attrs[attr_name]).dim return tensor[begin:end]