tensorial package#

Subpackages#

Submodules#

tensorial.base module#

class tensorial.base.Attr(irreps)[source]#

Bases: Module

Irreps object attribute

Parameters:

irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

create_tensor(value)[source]#
Parameters:

value (Any)

Return type:

IrrepsArray

from_tensor(tensor)[source]#

This can be overwritten to perform the backward transform of create_tensor

Parameters:

tensor (IrrepsArray)

Return type:

Any

irreps: Irreps#
class tensorial.base.IrrepsObj[source]#

Bases: object

An object that contains tensorial attributes.

tensorial.base.as_array(arr)[source]#
Get a standard JAX array given either:
  1. a numpy.ndarray

  2. an e3nn_jax.IrrepsArray, or

  3. a jax.Array (in which case it is returned unmodified)

Parameters:
  • arr – the array to get the value for

  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, IrrepsArray])

Return type:

Array

Returns:

the JAX array
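
As a rough illustration of the three cases listed above, the sketch below normalises its input to a plain array. It uses NumPy in place of JAX so that it runs without the JAX stack, and FakeIrrepsArray is a hypothetical stand-in for e3nn_jax.IrrepsArray (which keeps its data in an array attribute). It is not the library's implementation.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FakeIrrepsArray:
    """Hypothetical stand-in for e3nn_jax.IrrepsArray: irreps plus raw data."""

    irreps: str
    array: np.ndarray


def as_plain_array(arr):
    """Return the underlying array for any of the supported inputs."""
    if isinstance(arr, FakeIrrepsArray):
        # Unwrap: drop the irreps metadata and keep only the data
        return arr.array
    if isinstance(arr, np.ndarray):
        # np.ndarray plays the role of jax.Array here: returned unmodified
        return arr
    # Scalars and nested lists are converted to an array
    return np.asarray(arr)


wrapped = FakeIrrepsArray("1x1o", np.array([1.0, 2.0, 3.0]))
plain = np.array([1.0, 2.0, 3.0])
```

Calling as_plain_array(plain) hands back the very same object, while as_plain_array(wrapped) returns the data held inside the wrapper.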

tensorial.base.create(tensorial, value)[source]#
tensorial.base.create(attr, value)
Parameters:
  • tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

  • value (Mapping)

tensorial.base.create_tensor(tensorial, value)[source]#
tensorial.base.create_tensor(irreps, value)
tensorial.base.create_tensor(attr, value)

Create a tensor for a tensorial type

Parameters:
  • tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

  • value (Any | list[Any | list[ValueType] | dict[str, ValueType]] | dict[str, Any | list[ValueType] | dict[str, ValueType]])

Return type:

IrrepsArray
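
The several signatures listed above differ only in the name of the first argument, which suggests that the function dispatches on the type of tensorial. A minimal sketch of that pattern follows, assuming functools.singledispatch (the source does not state the mechanism) and using hypothetical stand-in classes rather than the real Attr and Irreps. Concatenating the per-attribute results is also an assumption.

```python
from functools import singledispatch


class FakeAttr:
    """Hypothetical stand-in for an Attr: knows how to encode one value."""

    def create_tensor(self, value):
        return [float(value)]


@singledispatch
def make_tensor(tensorial, value):
    # Fallback for types with no registered handler
    raise TypeError(f"Unsupported tensorial type: {type(tensorial).__name__}")


@make_tensor.register
def _(attr: FakeAttr, value):
    # A single attribute encodes the value itself
    return attr.create_tensor(value)


@make_tensor.register
def _(tensorial: dict, value):
    # A dict of attributes: encode each entry and concatenate (assumed behaviour)
    out = []
    for name, attr in tensorial.items():
        out.extend(make_tensor(attr, value[name]))
    return out


spec = {"charge": FakeAttr(), "mass": FakeAttr()}
result = make_tensor(spec, {"charge": 1, "mass": 12})
```

With this layout, supporting a new kind of tensorial object only needs one more registered handler; the call sites stay unchanged.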

tensorial.base.from_tensor(tensorial, value)[source]#
tensorial.base.from_tensor(irreps, value)
tensorial.base.from_tensor(attr, value)

Convert a tensor back into the value for a tensorial type (the backward transform of create_tensor)

Parameters:

tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

Return type:

Any | list[Any | list[ValueType] | dict[str, ValueType]] | dict[str, Any | list[ValueType] | dict[str, ValueType]]

tensorial.base.get(irreps_obj, tensor, attr_name=None)[source]#
Parameters:
  • irreps_obj (type[IrrepsObj])

  • tensor (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

  • attr_name (str)

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]

tensorial.base.irreps(tensorial)[source]#
tensorial.base.irreps(attr)

Get the irreps for a tensorial type

Parameters:

tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

Return type:

Irreps

tensorial.base.tensorial_attrs(irreps_obj)[source]#
Return type:

dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps]

tensorial.config module#

tensorial.config.instantiate(cfg, **kwargs)[source]#

Given an omegaconf configuration, instantiate the corresponding object

Parameters:

cfg (OmegaConf)

Return type:

Any

tensorial.nn module#

class tensorial.nn.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Applies a sequential chain of modules just like flax.linen.Sequential _except_ that flax’s version will expand any tuples that it receives when calling the next layer. This doesn’t play nice with types that subclass tuple, for example, jraph.GraphsTuple, because the layers expect to get a GraphsTuple, not the individual values that make it up.

Our behaviour is the same as flax.linen.Sequential if we get a tuple, but any subclasses thereof are kept intact when calling the next layer.
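
To make the distinction concrete, here is a minimal pure-Python sketch of the call rule described above. It is an illustration, not the module's code: Graph is a hypothetical stand-in for jraph.GraphsTuple, and chain is a hypothetical helper.

```python
from collections import namedtuple

# Hypothetical stand-in for jraph.GraphsTuple, which is a tuple subclass
Graph = namedtuple("Graph", ["nodes", "edges"])


def chain(layers, *args):
    """Call layers in order, unpacking plain tuples only."""
    outputs = layers[0](*args)
    for layer in layers[1:]:
        if type(outputs) is tuple:
            # A plain tuple is expanded into positional arguments
            outputs = layer(*outputs)
        else:
            # Tuple subclasses (e.g. named tuples) are passed through whole
            outputs = layer(outputs)
    return outputs


def double_nodes(graph):
    return graph._replace(nodes=[2 * n for n in graph.nodes])


def count_nodes(graph):
    return len(graph.nodes)


def split(x):
    return x, x + 1  # plain tuple: expanded for the next layer


def add(a, b):
    return a + b


n_nodes = chain([double_nodes, count_nodes], Graph(nodes=[1, 2, 3], edges=[]))
total = chain([split, add], 10)
```

The exact-type check (type(outputs) is tuple) is what keeps the named tuple intact; an isinstance check would also match subclasses and unpack the graph into its fields.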

Parameters:
  • layers (Sequence[Module | partial])

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

layers: Sequence[Module | partial]#
name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # either causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.

Return type:

None

tensorial.nn_utils module#

tensorial.nn_utils.get_jaxnn_activation(func)[source]#

Returns the activation function with the given name from the jax.nn module

Parameters:
  • func – the name of the function (as used in jax.nn)

  • func (Callable[[Array], Array])

Return type:

Callable[[Array], Array]

Returns:

the activation function

tensorial.nn_utils.prepare_mask(mask, array)[source]#

Prepare a mask for use with jnp.where(mask, array, …). This needs to be done to make sure the mask is of the right shape to be compatible with such an operation. The other alternative is

jnp.where(mask, array.T, ...).T

but this sometimes leads to creating a copy when doing one or both of the transposes. I’m not sure why, but this approach seems to avoid the problem.

Parameters:
  • mask – the mask to prepare

  • array – the array the mask will be applied to

  • mask (Bool[Array, 'n_elements'])

  • array (Array)

Return type:

Array

Returns:

the prepared mask; typically this is just the input mask padded with extra dimensions (or reduced)
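
The shape problem described above can be reproduced with NumPy, whose where broadcasts in the same way as jnp.where. The sketch below shows one way to pad a 1-D mask with trailing singleton axes; pad_mask is a hypothetical helper and may differ from what prepare_mask actually does.

```python
import numpy as np


def pad_mask(mask, array):
    """Reshape a 1-D mask so it broadcasts along the leading axis of `array`."""
    # (n,) -> (n, 1, 1, ...) with one singleton axis per trailing dimension
    return mask.reshape(mask.shape[0], *([1] * (array.ndim - 1)))


mask = np.array([True, False, True])
array = np.arange(12.0).reshape(3, 4)

padded = pad_mask(mask, array)
# Rows where the mask is False are replaced by zeros
masked = np.where(padded, array, 0.0)
```

Without the padding, a mask of shape (3,) does not broadcast against an array of shape (3, 4), because broadcasting aligns trailing axes first.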

tensorial.nn_utils.vwhere(values, types)[source]#
Parameters:
  • values (Array)

  • types (Array)

Return type:

Array

tensorial.tensors module#

class tensorial.tensors.AsIrreps(irreps)[source]#

Bases: Attr

Parameters:

irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

create_tensor(value)[source]#
Parameters:

value (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

Return type:

IrrepsArray

from_tensor(tensor)[source]#

This can be overwritten to perform the backward transform of create_tensor

Parameters:

tensor (IrrepsArray)

Return type:

IrrepsArray

class tensorial.tensors.CartesianTensor(formula, keep_ir=None, **irreps_dict)[source]#

Bases: Attr

Parameters:

formula (str)

change_of_basis: Array#
create_tensor(value)[source]#
Parameters:

value (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

Return type:

IrrepsArray

formula: str#
from_tensor(tensor)[source]#

Take an irrep tensor and perform the change of basis transformation back to a Cartesian tensor

Parameters:
  • tensor – the irrep tensor

  • tensor (Float[IrrepsArray, 'irreps'] | Float[IrrepsArray, 'batch irreps'])

Return type:

Array

Returns:

the Cartesian tensor

irreps_dict: dict#
keep_ir: Irreps | list[Irrep] | None#
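
For intuition about what the forward and backward transforms do, consider a rank-2 Cartesian tensor. It splits into a trace part (l = 0), an antisymmetric part (l = 1) and a symmetric traceless part (l = 2), and summing the parts recovers the original tensor. The NumPy sketch below shows that decomposition directly; it does not use the class's change_of_basis matrix, and the function names are hypothetical.

```python
import numpy as np


def decompose(t):
    """Split a 3x3 tensor into its l=0, l=1 and l=2 parts."""
    scalar = np.trace(t) / 3.0 * np.eye(3)    # l = 0: 1 independent component
    antisym = 0.5 * (t - t.T)                 # l = 1: 3 independent components
    sym_traceless = 0.5 * (t + t.T) - scalar  # l = 2: 5 independent components
    return scalar, antisym, sym_traceless


def recompose(parts):
    """Backward transform: sum the irreducible parts."""
    return sum(parts)


tensor = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0],
                   [7.0, 8.0, 10.0]])
parts = decompose(tensor)
restored = recompose(parts)
```

The component counts 1 + 3 + 5 add up to the 9 entries of the 3x3 tensor, which is why the round trip loses no information.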
class tensorial.tensors.NoOp(irreps)[source]#

Bases: Attr

An attribute that keeps IrrepsArrays with specified irreps unchanged

Parameters:

irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

create_tensor(value)[source]#
Parameters:

value (IrrepsArray)

Return type:

IrrepsArray

from_tensor(tensor)[source]#

This can be overwritten to perform the backward transform of create_tensor

Parameters:

tensor (IrrepsArray)

Return type:

IrrepsArray

class tensorial.tensors.OneHot(num_classes)[source]#

Bases: Attr

One-hot encoding as a direct sum of even scalars

Parameters:

num_classes (int)

create_tensor(value)[source]#
Parameters:

value (Union[Int[Array, 'n_vals'], Int[ndarray, 'n_vals']])

Return type:

IrrepsArray

property num_classes: int#
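
A one-hot vector has one entry per class, and each entry is a scalar that does not change under rotation or inversion, so the vector can be read as num_classes even scalars (written "num_classes x 0e" in e3nn notation). A NumPy sketch of the encoding itself, using a hypothetical helper name:

```python
import numpy as np


def one_hot(values, num_classes):
    """Encode integer class labels as rows of a one-hot matrix."""
    values = np.asarray(values)
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes)[values]


encoded = one_hot([0, 2, 1], num_classes=3)
```

Each input label becomes one row of length num_classes, so n_vals labels give an array of shape (n_vals, num_classes).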
class tensorial.tensors.SphericalHarmonic(irreps, normalise, normalisation=None, *, algorithm=None)[source]#

Bases: Attr

An attribute that is the spherical harmonics evaluated at some values

Parameters:
  • normalisation (Optional[Literal['integral', 'component', 'norm']])

  • algorithm (tuple[str])

algorithm: tuple[str] | None = None#
create_tensor(value)[source]#
Parameters:

value (Array | IrrepsArray)

Return type:

array

normalisation: Literal['integral', 'component', 'norm'] | None = None#
normalise: bool#

tensorial.typing module#

tensorial.utils module#

tensorial.utils.infer_backend(pytree)[source]#

Try to infer a backend from the passed pytree

Return type:

ModuleType

tensorial.utils.ones(irreps, leading_shape=(), dtype=None, np_=jax.numpy)[source]#

Create an IrrepsArray of ones.

Parameters:
  • irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

  • leading_shape (tuple)

  • dtype (dtype)

Return type:

IrrepsArray
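
The trailing dimension of such an array is fixed by the irreps: each irrep of degree l contributes 2l + 1 components per multiplicity. The pure-Python sketch below computes the expected shape from an irreps string, assuming the array shape is leading_shape followed by the total irreps dimension, as in e3nn_jax. irreps_dim and expected_shape are hypothetical helpers that only parse the simple "MULxLp" form and are not part of the library.

```python
def irreps_dim(irreps: str) -> int:
    """Total dimension of an irreps string such as '2x0e + 1x1o'."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        # An omitted multiplicity means 1, e.g. '1o' is the same as '1x1o'
        mul, ir = term.split("x") if "x" in term else ("1", term)
        degree = int(ir[:-1])  # strip the parity letter ('e' or 'o')
        total += int(mul) * (2 * degree + 1)
    return total


def expected_shape(irreps: str, leading_shape: tuple = ()) -> tuple:
    """Shape of an array holding `irreps` with the given leading shape."""
    return tuple(leading_shape) + (irreps_dim(irreps),)


shape = expected_shape("2x0e + 1x1o", leading_shape=(4,))
```

Here two scalars contribute 2 components and one vector contributes 3, so a leading shape of (4,) gives an array of shape (4, 5).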

tensorial.utils.ones_like(irreps_array)[source]#

Create an IrrepsArray of ones with the same shape as another IrrepsArray.

Parameters:

irreps_array (IrrepsArray)

Return type:

IrrepsArray

tensorial.utils.zeros(irreps, leading_shape=(), dtype=None, np_=jax.numpy)[source]#

Create an IrrepsArray of zeros.

Parameters:
  • irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

  • leading_shape (tuple)

  • dtype (dtype)

Return type:

IrrepsArray

tensorial.utils.zeros_like(irreps_array)[source]#

Create an IrrepsArray of zeros with the same shape as another IrrepsArray.

Parameters:

irreps_array (IrrepsArray)

Return type:

IrrepsArray

Module contents#

Library for machine learning on physical tensors

class tensorial.AsIrreps(irreps)[source]#

Bases: Attr

Parameters:

irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

create_tensor(value)[source]#
Parameters:

value (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

Return type:

IrrepsArray

from_tensor(tensor)[source]#

This can be overwritten to perform the backward transform of create_tensor

Parameters:

tensor (IrrepsArray)

Return type:

IrrepsArray

class tensorial.Attr(irreps)[source]#

Bases: Module

Irreps object attribute

Parameters:

irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

create_tensor(value)[source]#
Parameters:

value (Any)

Return type:

IrrepsArray

from_tensor(tensor)[source]#

This can be overwritten to perform the backward transform of create_tensor

Parameters:

tensor (IrrepsArray)

Return type:

Any

irreps: Irreps#
class tensorial.CartesianTensor(formula, keep_ir=None, **irreps_dict)[source]#

Bases: Attr

Parameters:

formula (str)

change_of_basis: Array#
create_tensor(value)[source]#
Parameters:

value (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

Return type:

IrrepsArray

formula: str#
from_tensor(tensor)[source]#

Take an irrep tensor and perform the change of basis transformation back to a Cartesian tensor

Parameters:
  • tensor – the irrep tensor

  • tensor (Float[IrrepsArray, 'irreps'] | Float[IrrepsArray, 'batch irreps'])

Return type:

Array

Returns:

the Cartesian tensor

irreps_dict: dict#
keep_ir: Irreps | list[Irrep] | None#
class tensorial.IrrepsObj[source]#

Bases: object

An object that contains tensorial attributes.

class tensorial.NoOp(irreps)[source]#

Bases: Attr

An attribute that keeps IrrepsArrays with specified irreps unchanged

Parameters:

irreps (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

create_tensor(value)[source]#
Parameters:

value (IrrepsArray)

Return type:

IrrepsArray

from_tensor(tensor)[source]#

This can be overwritten to perform the backward transform of create_tensor

Parameters:

tensor (IrrepsArray)

Return type:

IrrepsArray

class tensorial.OneHot(num_classes)[source]#

Bases: Attr

One-hot encoding as a direct sum of even scalars

Parameters:

num_classes (int)

create_tensor(value)[source]#
Parameters:

value (Union[Int[Array, 'n_vals'], Int[ndarray, 'n_vals']])

Return type:

IrrepsArray

property num_classes: int#
class tensorial.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))[source]#

Bases: Module[GraphsTuple, GraphsTuple]

Tensorial REAX module.

Parameters:
  • model (Module)

  • loss_fn (Callable[[GraphsTuple, GraphsTuple], Array])

  • optimizer (GradientTransformation | Callable[[], GradientTransformation])

  • scheduler (Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] | None)

  • metrics (dict[str, Metric | str] | None)

  • output (Sequence[str] | None)

Init function.

static calculate_metrics(predictions, targets, metrics)[source]#
Parameters:
  • predictions (GraphsTuple)

  • targets (GraphsTuple)

  • metrics (dict[str, Metric | str])

Return type:

dict[str, Metric]

configure_model(_stage, batch, /)[source]#

Called at the beginning of each stage.

A chance to configure the model. This method should be idempotent, i.e. calling it a second time should do nothing.

Parameters:

_stage (Stage)

configure_optimizers()[source]#

Create the optimizer(s) to use during training.

property debug: bool#
on_before_optimizer_step(_optimizer, grad, /)[source]#

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:
  • _optimizer – Current optimizer being used.

  • grad – The gradients dictionary from JAX

  • _optimizer (Optimizer)

  • grad (dict[str, Any])

predict_step(batch, _batch_idx, /)[source]#

Make a model prediction and return the result.

Parameters:
  • batch (GraphsTuple)

  • _batch_idx (int)

Return type:

GraphsTuple

static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())[source]#

Calculate the loss and, optionally, metrics.

Parameters:
  • params (PyTree)

  • inputs (GraphsTuple)

  • _targets (GraphsTuple)

  • model (Callable[[PyTree, GraphsTuple], GraphsTuple])

  • loss_fn (Callable)

  • metrics (MetricCollection | None)

  • output (tuple[str, ...])

Return type:

tuple[Array, dict]

test_step(batch, _batch_idx, /)[source]#

Test step.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput | None

training_step(batch, _batch_idx, /)[source]#

Train step.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput

validation_step(batch, _batch_idx, /)[source]#

Validate step.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput | None

class tensorial.SphericalHarmonic(irreps, normalise, normalisation=None, *, algorithm=None)[source]#

Bases: Attr

An attribute that is the spherical harmonics evaluated at some values

Parameters:
  • normalisation (Optional[Literal['integral', 'component', 'norm']])

  • algorithm (tuple[str])

algorithm: tuple[str] | None = None#
create_tensor(value)[source]#
Parameters:

value (Array | IrrepsArray)

Return type:

array

normalisation: Literal['integral', 'component', 'norm'] | None = None#
normalise: bool#
tensorial.as_array(arr)[source]#
Get a standard JAX array given either:
  1. a numpy.ndarray

  2. an e3nn_jax.IrrepsArray, or

  3. a jax.Array (in which case it is returned unmodified)

Parameters:
  • arr – the array to get the value for

  • arr (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, IrrepsArray])

Return type:

Array

Returns:

the JAX array

tensorial.create(tensorial, value)[source]#
tensorial.create(attr, value)
Parameters:
  • tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

  • value (Mapping)

tensorial.create_tensor(tensorial, value)[source]#
tensorial.create_tensor(irreps, value)
tensorial.create_tensor(attr, value)

Create a tensor for a tensorial type

Parameters:
  • tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

  • value (Any | list[Any | list[ValueType] | dict[str, ValueType]] | dict[str, Any | list[ValueType] | dict[str, ValueType]])

Return type:

IrrepsArray

tensorial.from_tensor(tensorial, value)[source]#
tensorial.from_tensor(irreps, value)
tensorial.from_tensor(attr, value)

Convert a tensor back into the value for a tensorial type (the backward transform of create_tensor)

Parameters:

tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

Return type:

Any | list[Any | list[ValueType] | dict[str, ValueType]] | dict[str, Any | list[ValueType] | dict[str, ValueType]]

tensorial.get(irreps_obj, tensor, attr_name=None)[source]#
Parameters:
  • irreps_obj (type[IrrepsObj])

  • tensor (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

  • attr_name (str)

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]

tensorial.irreps(tensorial)[source]#
tensorial.irreps(attr)

Get the irreps for a tensorial type

Parameters:

tensorial (Attr | IrrepsObj | type | dict | FrozenDict | Irreps)

Return type:

Irreps

tensorial.tensorial_attrs(irreps_obj)[source]#
Return type:

dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps]