tensorial.gcnn package

Subpackages

Submodules

tensorial.gcnn.calc module

tensorial.gcnn.calc.cell_volume(cell_vectors, np_=None)

Computes the volume of a unit cell defined by its cell vectors.

The volume is the absolute value of the determinant of the 3x3 matrix whose rows are the cell vectors a, b and c, which equals the absolute value of the scalar triple product a · (b × c). It is a standard quantity in crystallography and computational physics.

Parameters:
  • cell_vectors (Union[Float[Array, '3 3'], Float[ndarray, '3 3']]) – A 3x3 matrix where each row represents a cell vector in three-dimensional space.

  • np_ – The numerical library backend to use for computation. If None, the backend is inferred from the input tensor.

Return type:

Union[Array, ndarray]

Returns:

The volume of the unit cell as a scalar tensor.

Raises:

LinAlgError – If the cell vectors are linearly dependent, resulting in a zero determinant.
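A minimal NumPy sketch of the computation described above, assuming rows are cell vectors. `cell_volume_sketch` is a hypothetical stand-in, not the library function; the example also checks the result against the scalar triple product |a · (b × c)|.

```python
import numpy as np


def cell_volume_sketch(cell_vectors, np_=np):
    """Hypothetical stand-in for cell_volume: |det| of the 3x3 cell matrix."""
    cell = np_.asarray(cell_vectors)
    if cell.shape != (3, 3):
        raise ValueError(f"expected a 3x3 matrix, got shape {cell.shape}")
    return np_.abs(np_.linalg.det(cell))


# Rows are the cell vectors a, b and c of a sheared (triclinic) cell.
cell = np.array([
    [4.0, 0.0, 0.0],
    [1.0, 3.0, 0.0],
    [0.5, 0.5, 2.0],
])
volume = cell_volume_sketch(cell)

# The same volume from the scalar triple product |a . (b x c)|.
a, b, c = cell
triple = abs(np.dot(a, np.cross(b, c)))
print(volume, triple)  # both approximately 24.0
```

Passing `jax.numpy` as `np_` would follow the same code path, since it provides `asarray`, `abs` and `linalg.det`.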

tensorial.gcnn.derivatives module

class tensorial.gcnn.derivatives.Grad(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

The Grad class computes gradients of graph-based functions with respect to specified graph attributes. It enables automatic differentiation of operations defined on graph structures, such as computing how changes in node positions affect edge lengths or other graph properties. The class supports both scalar and vector-valued gradients and integrates with JAX for efficient computation.

Parameters:
  • func (Callable[[GraphsTuple], GraphsTuple])

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])

  • out_field (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])

  • sign (float)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])
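The idea behind Grad can be sketched in plain JAX: differentiate a scalar graph quantity with respect to one node field and write the signed result to another field. This is a minimal sketch, not Grad's implementation. `ToyGraph`, `toy_energy` and `grad_into_field` are hypothetical, and using `sign=-1.0` to turn an energy gradient into forces is an assumption about how the sign argument is meant to be used.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyGraph(NamedTuple):
    """Hypothetical stand-in for jraph.GraphsTuple with only the fields used here."""
    nodes: dict
    senders: jnp.ndarray
    receivers: jnp.ndarray


def toy_energy(graph: ToyGraph) -> jnp.ndarray:
    """Toy scalar energy: sum of squared edge lengths."""
    pos = graph.nodes["positions"]
    edge_vectors = pos[graph.receivers] - pos[graph.senders]
    return jnp.sum(edge_vectors**2)


def grad_into_field(graph: ToyGraph, wrt="positions", out_field="forces", sign=-1.0):
    """Differentiate toy_energy w.r.t. one node field and store sign * gradient."""

    def energy_of(values):
        # Rebuild the graph with the differentiated field swapped in.
        return toy_energy(graph._replace(nodes={**graph.nodes, wrt: values}))

    gradient = jax.grad(energy_of)(graph.nodes[wrt])
    return graph._replace(nodes={**graph.nodes, out_field: sign * gradient})


graph = ToyGraph(
    nodes={"positions": jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])},
    senders=jnp.array([0]),
    receivers=jnp.array([1]),
)
out = grad_into_field(graph)
# forces: +2 along x on node 0 and -2 along x on node 1 (the spring pulls them together)
```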

func: Callable[[GraphsTuple], GraphsTuple]
name: str | None = None
of: str | tuple[str, ...]
out_field: str | tuple[str, ...] | list[str | tuple[str, ...]] = 'auto'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sign: float = 1.0
wrt: str | tuple[str, ...] | list[str | tuple[str, ...]]
class tensorial.gcnn.derivatives.Jacfwd(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Parameters:
  • func (Callable[[GraphsTuple], GraphsTuple])

  • of (Union[str, tuple[str, ...]])

  • wrt (str | Sequence[Union[str, tuple[str, ...]]])

  • out_field (str | Sequence[str])

  • sign (float)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

func: Callable[[GraphsTuple], GraphsTuple]
name: str | None = None
of: str | tuple[str, ...]
out_field: str | Sequence[str] = 'auto'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sign: float = 1.0
wrt: str | Sequence[str | tuple[str, ...]]
class tensorial.gcnn.derivatives.Jacobian(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Parameters:
  • func (Callable[[GraphsTuple], GraphsTuple])

  • of (Union[str, tuple[str, ...]])

  • wrt (str | Sequence[Union[str, tuple[str, ...]]])

  • out_field (str | Sequence[str])

  • sign (float)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

func: Callable[[GraphsTuple], GraphsTuple]
name: str | None = None
of: str | tuple[str, ...]
out_field: str | Sequence[str] = 'auto'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sign: float = 1.0
wrt: str | Sequence[str | tuple[str, ...]]
tensorial.gcnn.derivatives.grad(of, wrt, sign=1.0, has_aux=False)

Build a partially initialised Grad whose only remaining argument is the function to differentiate.

Parameters:
  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function
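Judging from the return type, grad() does not take the function to differentiate directly: it returns a callable that accepts that function and yields a new function. The following is a minimal sketch of that partial-application pattern in plain JAX, with dicts standing in for GraphsTuple. `grad_sketch` and `model` are hypothetical; the real function may additionally handle `has_aux` and return a GraphsTuple.

```python
import jax
import jax.numpy as jnp


def grad_sketch(of, wrt, sign=1.0):
    """Hypothetical factory mirroring grad(of, wrt, sign): returns a decorator."""

    def decorator(func):
        # func maps a dict of arrays to a dict of arrays (stand-in for a graph).
        def wrapped(data):
            def scalar_of(x):
                return func({**data, wrt: x})[of]

            return sign * jax.grad(scalar_of)(data[wrt])

        return wrapped

    return decorator


@grad_sketch(of="energy", wrt="positions", sign=-1.0)
def model(data):
    # Toy model: harmonic energy 0.5 * sum(x^2), whose gradient is x itself.
    return {**data, "energy": 0.5 * jnp.sum(data["positions"] ** 2)}


positions = jnp.array([1.0, -2.0, 3.0])
forces = model({"positions": positions})  # equals -positions
```

Returning a decorator keeps the differentiation settings (`of`, `wrt`, `sign`) separate from the model, so the same settings can be applied to several models.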

tensorial.gcnn.derivatives.hessian(of, wrt, sign=1.0, has_aux=False)

Build a partially initialised derivative function whose only remaining argument is the function to differentiate.

Parameters:
  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function
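For intuition about what a second derivative of a graph quantity contains, here is plain `jax.hessian` applied to a toy two-particle energy. This is not tensorial's hessian(); `pair_energy` is hypothetical, and flattening the positions to a 1-D vector is a choice made here so that the Hessian is an ordinary matrix.

```python
import jax
import jax.numpy as jnp


def pair_energy(flat_positions):
    """Toy energy of two particles joined by a unit-stiffness spring."""
    pos = flat_positions.reshape(2, 3)
    return 0.5 * jnp.sum((pos[1] - pos[0]) ** 2)


flat = jnp.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0])
hess = jax.hessian(pair_energy)(flat)  # shape (6, 6)

# Block structure [[I, -I], [-I, I]]: moving either particle changes the
# spring energy, and the off-diagonal blocks couple the two particles.
expected = jnp.kron(jnp.array([[1.0, -1.0], [-1.0, 1.0]]), jnp.eye(3))
```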

tensorial.gcnn.derivatives.jacfwd(of, wrt, sign=1.0, has_aux=False)

Build a partially initialised derivative function whose only remaining argument is the function to differentiate.

Parameters:
  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function

tensorial.gcnn.derivatives.jacobian(of, wrt, sign=1.0, has_aux=False)

Build a partially initialised derivative function whose only remaining argument is the function to differentiate.

Parameters:
  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function

tensorial.gcnn.derivatives.jacrev(of, wrt, sign=1.0, has_aux=False)

Build a partially initialised derivative function whose only remaining argument is the function to differentiate.

Parameters:
  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function
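jacfwd() and jacrev() presumably mirror `jax.jacfwd` and `jax.jacrev`, which compute the same Jacobian in forward and reverse mode; forward mode is usually cheaper when there are fewer inputs than outputs, and reverse mode when there are fewer outputs. The sketch uses plain JAX on a toy vector-valued function (`edge_lengths` is hypothetical) to show how changes in node positions affect edge lengths.

```python
import jax
import jax.numpy as jnp

senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])


def edge_lengths(positions):
    """Vector-valued toy function: length of each edge, shape (num_edges,)."""
    vec = positions[receivers] - positions[senders]
    return jnp.linalg.norm(vec, axis=-1)


positions = jnp.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [3.0, 4.0, 12.0]])
j_fwd = jax.jacfwd(edge_lengths)(positions)  # shape (num_edges, num_nodes, 3)
j_rev = jax.jacrev(edge_lengths)(positions)

# d(length of edge 0) / d(position of node 1) is the unit edge vector (0.6, 0.8, 0).
unit_edge_0 = j_fwd[0, 1]
```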

tensorial.gcnn.keys module

tensorial.gcnn.keys.predicted(key, delimiter='_')

Helper to create a ‘predicted’ key.

Parameters:
  • key (str)

  • delimiter (str)

Return type:

str

tensorial.gcnn.losses module

class tensorial.gcnn.losses.GraphLoss(label)

Bases: Module

Parameters:

label (str)

label()

Get a label for this loss function

Return type:

str

class tensorial.gcnn.losses.Loss(loss_fn, targets, predictions=None, *, reduction='mean', label=None, mask_field=None)

Bases: GraphLoss

A simple loss that reads values from the graph and passes them to a function operating on numerical arrays, such as an optax loss.

Initializes the loss function wrapper with specified parameters.

This constructor sets up a loss function wrapper that can be used to compute loss values based on target and prediction fields. It supports various loss functions and provides options for reduction and masking.

Parameters:
  • loss_fn (str | Callable[[Array, Array], Array]) – The loss function to use, either as a string identifier or a callable implementing the PureLossFn protocol.

  • targets (str) – The path to the target field in the data structure.

  • predictions (str | None) – The path to the prediction field in the data structure. If None, the target field path will be used.

  • reduction (Literal['sum', 'mean']) – The reduction method to apply to the computed losses. Either “sum” or “mean”.

  • label (str) – The label to use for the loss function. If None, the prediction field path will be used as the label.

  • mask_field (str | None) – The path to the mask field in the data structure. If None, no masking will be applied.

Raises:
  • ValueError – If the reduction method is not “sum” or “mean”.

  • TypeError – If the loss function is not a valid string or callable.
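Conceptually, Loss reads the target and prediction arrays, applies an elementwise loss such as an optax loss, optionally masks entries, and reduces with "sum" or "mean". A NumPy sketch under those assumptions: `masked_loss` and `squared_error` are hypothetical, and dividing by the number of unmasked entries for "mean" is an assumption rather than documented behaviour of Loss.

```python
import numpy as np


def masked_loss(targets, predictions, loss_fn, reduction="mean", mask=None):
    """Hypothetical sketch: elementwise loss, optional mask, then sum or mean."""
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")
    values = loss_fn(predictions, targets)
    mask = np.ones_like(values, dtype=bool) if mask is None else np.asarray(mask, dtype=bool)
    # Masked-out entries contribute nothing to the total.
    total = np.where(mask, values, 0.0).sum()
    if reduction == "sum":
        return total
    # Average over the unmasked entries only.
    return total / max(int(mask.sum()), 1)


def squared_error(predictions, targets):
    # Same elementwise form as optax.squared_error.
    return (predictions - targets) ** 2


targets = np.array([1.0, 2.0, 3.0, 4.0])
predictions = np.array([1.5, 2.0, 2.0, 100.0])
mask = np.array([True, True, True, False])  # e.g. a padding entry
mean_loss = masked_loss(targets, predictions, squared_error, mask=mask)
sum_loss = masked_loss(targets, predictions, squared_error, reduction="sum", mask=mask)
```

Without the mask, the padding entry (error 96) would dominate both reductions.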

class tensorial.gcnn.losses.WeightedLoss(loss_fns, weights=None)

Bases: GraphLoss

A weighted combination of multiple graph loss functions.

This class combines multiple graph loss functions with specified weights to create a composite loss function. The weights determine the contribution of each individual loss function to the final combined loss.

Parameters:
  • loss_fns (Sequence[GraphLoss]) – Sequence of graph loss functions to combine.

  • weights (Sequence[float] | None) – Sequence of weights for each loss function. If None, all weights are set to 1.0.

Raises:

ValueError – If any element in loss_fns is not a GraphLoss instance, or if the number of weights does not match the number of loss functions.

loss_with_contributions(predictions, target)
Parameters:
  • predictions (GraphsTuple)

  • target (GraphsTuple)

Return type:

tuple[Array, dict[str, float]]
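The return type of loss_with_contributions (a total plus a dict of floats) suggests a weighted sum that is also reported per component. A minimal sketch, assuming each contribution is the weight times the component loss and is keyed by that component's label; `weighted_loss_with_contributions` is hypothetical and takes precomputed loss values instead of GraphsTuples.

```python
def weighted_loss_with_contributions(losses, weights=None):
    """Hypothetical sketch: combine labelled loss values with weights.

    `losses` maps a label to an already computed scalar loss.
    Returns the weighted total and a dict of per-label weighted contributions.
    """
    labels = list(losses)
    if weights is None:
        weights = [1.0] * len(labels)
    if len(weights) != len(labels):
        raise ValueError("number of weights must match number of losses")
    contributions = {label: float(w * losses[label]) for label, w in zip(labels, weights)}
    total = sum(contributions.values())
    return total, contributions


total, parts = weighted_loss_with_contributions({"energy": 0.5, "forces": 2.0}, weights=[1.0, 10.0])
```

Reporting contributions separately makes it easy to log which term dominates the total during training.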

property weights

tensorial.gcnn.metrics module

class tensorial.gcnn.metrics.GraphMetric(state=None)

Bases: Metric

Parameters:

state (Optional[Metric[TypeVar(OutT)]])

compute()

Compute the metric.

Return type:

TypeVar(OutT)

create(predictions, targets=None)

Create a new metric instance from data.

Parameters:
  • predictions (GraphsTuple)

  • targets (GraphsTuple | None)

Return type:

GraphMetric

property is_empty: bool
mask_key: ClassVar[str | tuple[str, ...] | None] = 'auto'
merge(other)

Merge the metric with data from another metric instance of the same type.

Parameters:

other (GraphMetric)

Return type:

GraphMetric

property metric: Metric[OutT] | None
normalise_by: ClassVar[str | tuple[str, ...] | None] = None
parent: ClassVar[Metric]
pred_key: ClassVar[str | tuple[str, ...]]
target_key: ClassVar[str | tuple[str, ...] | None] = None
tensorial.gcnn.metrics.graph_metric(metric, predictions, targets=None, mask='auto', normalise_by=None)
Parameters:
  • metric (str | Metric | type[Metric])

  • predictions (Union[str, tuple[str, ...]])

  • targets (Union[str, tuple[str, ...], None])

  • mask (Union[str, tuple[str, ...], Literal['auto'], None])

  • normalise_by (Union[str, tuple[str, ...], None])

Return type:

GraphMetric
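GraphMetric follows a create / merge / compute protocol: build a metric from one batch, merge it with metrics from other batches, then compute the final value. The sketch shows the pattern with a hypothetical mean-absolute-error metric on plain arrays (`MeanMetricSketch` is not part of the library); in GraphMetric the arrays would presumably be read from the GraphsTuple fields named by pred_key, target_key and mask_key.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class MeanMetricSketch:
    """Hypothetical mergeable metric following a create / merge / compute pattern."""

    total: float = 0.0
    count: int = 0

    @classmethod
    def create(cls, predictions, targets, mask=None):
        errors = np.abs(np.asarray(predictions) - np.asarray(targets))
        if mask is not None:
            errors = errors[np.asarray(mask, dtype=bool)]
        return cls(total=float(errors.sum()), count=int(errors.size))

    def merge(self, other):
        # Merging adds sufficient statistics, so batches can be combined in any order.
        return MeanMetricSketch(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count if self.count else float("nan")


batch_a = MeanMetricSketch.create([1.0, 2.0], [1.5, 2.0])
batch_b = MeanMetricSketch.create([0.0, 4.0, 9.0], [1.0, 4.0, 0.0], mask=[True, True, False])
mae = batch_a.merge(batch_b).compute()
```

Storing a running total and count, rather than per-batch means, is what makes the merge exact when batches have different sizes.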

tensorial.gcnn.random module

tensorial.gcnn.random.spatial_graph(rng_key, num_nodes=None, num_graphs=None, cutoff=0.4, nodes=None)

Create one or more graphs whose nodes have random positions.

Parameters:
  • rng_key (Array)

  • num_nodes (int)

  • nodes (dict[str, Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray], int], Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]]]] | None)

Return type:

GraphsTuple | list[GraphsTuple]
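One plausible way to build such a graph is to sample positions and connect every pair of nodes closer than cutoff. A NumPy sketch under stated assumptions: sampling uniformly in the unit cube, the strict `<` comparison and the use of a NumPy generator in place of a JAX rng_key are all assumptions, and `random_spatial_graph_sketch` is hypothetical.

```python
import numpy as np


def random_spatial_graph_sketch(seed, num_nodes, cutoff=0.4):
    """Hypothetical sketch: random positions in the unit cube, edges within cutoff."""
    rng = np.random.default_rng(seed)
    positions = rng.random((num_nodes, 3))
    # Pairwise displacements (N, N, 3) and distances (N, N).
    displacements = positions[None, :, :] - positions[:, None, :]
    distances = np.linalg.norm(displacements, axis=-1)
    # Connect every ordered pair closer than the cutoff, excluding self-loops.
    within = (distances < cutoff) & ~np.eye(num_nodes, dtype=bool)
    senders, receivers = np.nonzero(within)
    return positions, senders, receivers


positions, senders, receivers = random_spatial_graph_sketch(seed=0, num_nodes=20)
```

The dense N x N distance matrix is fine for small test graphs; large systems would normally use a neighbour list instead.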

tensorial.gcnn.typing module

tensorial.gcnn.utils module

class tensorial.gcnn.utils.UpdateDict(updating)

Bases: MutableMapping

A mutable mapping that records updates to a dictionary without modifying the dictionary passed in. Once all updates have been made, call the _asdict() method to obtain a new dictionary with the modifications applied.

Parameters:

updating (dict)

DELETED = ()
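A minimal sketch of the copy-on-write idea, assuming UpdateDict keeps changes in a separate overlay and marks removed keys with a deletion sentinel (the role DELETED appears to play). `UpdateDictSketch` is hypothetical and not the library class.

```python
from collections.abc import MutableMapping

# Sentinel for removed keys. The documented attribute is DELETED = (); a fresh
# object() is used here so that a stored empty tuple is not mistaken for a deletion.
_DELETED = object()


class UpdateDictSketch(MutableMapping):
    """Hypothetical copy-on-write view over a dict; the wrapped dict is never mutated."""

    def __init__(self, updating):
        self._base = updating
        self._changes = {}

    def __getitem__(self, key):
        if key in self._changes:
            value = self._changes[key]
            if value is _DELETED:
                raise KeyError(key)
            return value
        return self._base[key]

    def __setitem__(self, key, value):
        self._changes[key] = value

    def __delitem__(self, key):
        if key not in self:
            raise KeyError(key)
        self._changes[key] = _DELETED

    def __iter__(self):
        for key in self._base:
            if self._changes.get(key) is not _DELETED:
                yield key
        for key, value in self._changes.items():
            if key not in self._base and value is not _DELETED:
                yield key

    def __len__(self):
        return sum(1 for _ in self)

    def _asdict(self):
        return {key: self[key] for key in self}


original = {"a": 1, "b": 2}
view = UpdateDictSketch(original)
view["c"] = 3
del view["a"]
view["b"] = 20
result = view._asdict()  # {"b": 20, "c": 3}; original is unchanged
```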
tensorial.gcnn.utils.path_from_str(path_str, delimiter='.')

Split up a path string into a tuple of path components

Parameters:
  • path_str (Union[str, tuple[str, ...]])

  • delimiter (str)

Return type:

tuple[str, ...]

tensorial.gcnn.utils.path_to_str(path, delimiter='.')

Return a string representation of a tree path

Parameters:
  • path (Union[str, tuple[str, ...]])

  • delimiter (str)

Return type:

str
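Both helpers accept either form of a path (str or tuple[str, ...]). A sketch of the round trip, assuming a plain split/join on the delimiter and that inputs already in the target form are passed through unchanged; the `_sketch` functions are hypothetical.

```python
def path_from_str_sketch(path_str, delimiter="."):
    """Hypothetical re-implementation: split "nodes.energy" into ("nodes", "energy")."""
    if isinstance(path_str, tuple):
        # Already a tuple path; pass it through unchanged.
        return path_str
    return tuple(path_str.split(delimiter))


def path_to_str_sketch(path, delimiter="."):
    """Hypothetical re-implementation: join ("nodes", "energy") into "nodes.energy"."""
    if isinstance(path, str):
        return path
    return delimiter.join(path)


parts = path_from_str_sketch("nodes.energy")  # ("nodes", "energy")
text = path_to_str_sketch(parts)  # "nodes.energy"
```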

Module contents

class tensorial.gcnn.EdgeVectors(as_irreps_arrays=False, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Create edge vectors from atomic positions, taking the unit cell into account if one is present.

Parameters:
  • as_irreps_arrays (bool)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])
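With periodic boundary conditions an edge can connect a node to a periodic image of another node, so the raw position difference needs an integer combination of the cell vectors added to it. The sketch assumes per-edge integer image shifts are stored alongside the edges; `edge_vectors_sketch`, its argument names and the sign convention of the shift are hypothetical.

```python
import numpy as np


def edge_vectors_sketch(positions, senders, receivers, cell=None, cell_shifts=None):
    """Hypothetical sketch: r_receiver - r_sender, plus a periodic image shift.

    positions:   (num_nodes, 3) Cartesian coordinates
    cell:        (3, 3) matrix whose rows are the cell vectors, or None
    cell_shifts: (num_edges, 3) integer image offsets in units of the cell vectors
    """
    vectors = positions[receivers] - positions[senders]
    if cell is not None and cell_shifts is not None:
        # An integer shift n = (n1, n2, n3) moves the receiver by n1*a + n2*b + n3*c.
        vectors = vectors + cell_shifts @ cell
    return vectors


cell = np.diag([10.0, 10.0, 10.0])
positions = np.array([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]])
senders, receivers = np.array([0]), np.array([1])
# The nearest image of node 1 lies one cell vector in the -a direction.
shifts = np.array([[-1, 0, 0]])
vec = edge_vectors_sketch(positions, senders, receivers, cell, shifts)  # [[-1, 0, 0]]
```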

as_irreps_arrays: bool = False
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class tensorial.gcnn.EdgewiseDecoding(attrs, in_field='attributes', parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Decode the direct sum of irreps stored in in_field and store each tensor as a node value, keyed by the corresponding name in attrs.

Parameters:
  • attrs (IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps])

  • in_field (str)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])
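Storing several tensors as a direct sum amounts to concatenating their irrep components along the last axis; decoding splits them back by irrep dimension. A plain NumPy sketch of that round trip: the attribute names and irreps are hypothetical, and the real modules work with irreps-annotated arrays rather than bare NumPy arrays.

```python
import numpy as np

# Hypothetical attributes: name -> dimension of its irrep (2l + 1).
ATTR_DIMS = {"charge": 1, "dipole": 3, "quadrupole": 5}  # 0e, 1o, 2e


def encode(values):
    """Concatenate per-attribute arrays into one direct-sum feature array."""
    return np.concatenate([values[name] for name in ATTR_DIMS], axis=-1)


def decode(features):
    """Split a direct-sum feature array back into named blocks."""
    split_points = np.cumsum(list(ATTR_DIMS.values()))[:-1]
    blocks = np.split(features, split_points, axis=-1)
    return dict(zip(ATTR_DIMS, blocks))


num_edges = 4
rng = np.random.default_rng(0)
values = {name: rng.normal(size=(num_edges, dim)) for name, dim in ATTR_DIMS.items()}
features = encode(values)  # shape (4, 9)
decoded = decode(features)
```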

attrs: IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps]
in_field: str = 'attributes'
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class tensorial.gcnn.EdgewiseEmbedding(attrs, out_field='attributes', parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Parameters:
  • attrs (IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps])

  • out_field (str)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

attrs: IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps]
name: str | None = None
out_field: str = 'attributes'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
tensorial.gcnn.EdgewiseEncoding

alias of EdgewiseEmbedding

class tensorial.gcnn.EdgewiseLinear(irreps_out, irreps_in=None, field='features', out_field='features', parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Edgewise linear operation

Parameters:
  • irreps_out (str | Irreps)

  • irreps_in (Irreps | None)

  • field (str)

  • out_field (str | None)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

field: str = 'features'
irreps_in: Irreps | None = None
irreps_out: str | Irreps
name: str | None = None
out_field: str | None = 'features'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

class tensorial.gcnn.Grad(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

The Grad class computes gradients of graph-based functions with respect to specified graph attributes. It enables automatic differentiation of operations defined on graph structures, such as computing how changes in node positions affect edge lengths or other graph properties. The class supports both scalar and vector-valued gradients and integrates with JAX for efficient computation.

Parameters:
  • func (Callable[[GraphsTuple], GraphsTuple])

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])

  • out_field (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])

  • sign (float)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

func: Callable[[GraphsTuple], GraphsTuple]
name: str | None = None
of: str | tuple[str, ...]
out_field: str | tuple[str, ...] | list[str | tuple[str, ...]] = 'auto'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sign: float = 1.0
wrt: str | tuple[str, ...] | list[str | tuple[str, ...]]
class tensorial.gcnn.GraphLoss(label)

Bases: Module

Parameters:

label (str)

label()

Get a label for this loss function

Return type:

str

class tensorial.gcnn.GraphMetric(state=None)

Bases: Metric

Parameters:

state (Optional[Metric[TypeVar(OutT)]])

compute()

Compute the metric.

Return type:

TypeVar(OutT)

create(predictions, targets=None)

Create a new metric instance from data.

Parameters:
  • predictions (GraphsTuple)

  • targets (GraphsTuple | None)

Return type:

GraphMetric

property is_empty: bool
mask_key: ClassVar[str | tuple[str, ...] | None] = 'auto'
merge(other)

Merge the metric with data from another metric instance of the same type.

Parameters:

other (GraphMetric)

Return type:

GraphMetric

property metric: Metric[OutT] | None
normalise_by: ClassVar[str | tuple[str, ...] | None] = None
parent: ClassVar[Metric]
pred_key: ClassVar[str | tuple[str, ...]]
target_key: ClassVar[str | tuple[str, ...] | None] = None
class tensorial.gcnn.IndexedLinear(irreps_out, num_types, index_field, field, out_field=None, name=None, parent=<flax.linen.module._Sentinel object>)

Bases: Module

Applies an indexed linear transformation to a field in a GraphsTuple.

This module performs a linear transformation on a per-element basis, where each element is routed through a specific linear layer determined by an associated index array. A separate set of learnable weights is maintained for each index value.

Variables:
  • irreps_out (str | e3j.Irreps) – The output irreducible representations of the linear transformation.

  • num_types (int) – Number of distinct index values, corresponding to the number of weight sets.

  • index_field (str) – Dot-separated path to the index array within the GraphsTuple.

  • field (str) – Dot-separated path to the input features within the GraphsTuple.

  • out_field (Optional[str]) – Dot-separated path where output features should be written. If None, overwrites field.

  • name (str) – Optional name for the internal Linear module.

Parameters:

graph (jraph.GraphsTuple) – A graph with fields specified by index_field and field.

Returns:

A new graph with updated features at out_field, where each input vector has been transformed by a linear layer corresponding to its associated index.

Return type:

jraph.GraphsTuple

Raises:
  • KeyError – If the specified field or index_field does not exist in the graph.

  • ValueError – If the index values exceed the range [0, num_types - 1].

Example

If graph.nodes contains input features and graph.nodes[“type”] contains integer indices in [0, num_types), the module applies a learned linear map per type:

IndexedLinear("64x0e", num_types=5, index_field="nodes.type", field="nodes.feat")

Each node’s “feat” will be transformed by a different Linear layer according to its “type”.

Parameters:
  • irreps_out (str | Irreps)

  • num_types (int)

  • index_field (str)

  • field (str)

  • out_field (str | None)

  • name (Optional[str])

  • parent (Union[Module, Scope, _Sentinel, None])
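Keeping one weight set per index value can be understood as a gather of per-element weights followed by a batched matrix product. The NumPy sketch below uses dense matrices, which is only faithful for scalar irreps such as "64x0e"; the actual module works on general irreps and preserves equivariance, so treat `indexed_linear_sketch` as a hypothetical illustration of the routing only.

```python
import numpy as np


def indexed_linear_sketch(features, types, weights):
    """Hypothetical sketch: apply weights[types[i]] to features[i].

    features: (num_nodes, d_in)
    types:    (num_nodes,) integers in [0, num_types)
    weights:  (num_types, d_in, d_out), one weight matrix per type
    """
    types = np.asarray(types)
    if types.min() < 0 or types.max() >= weights.shape[0]:
        raise ValueError("index values must lie in [0, num_types - 1]")
    # Gather one matrix per node, then do a batched vector-matrix product.
    return np.einsum("ni,nio->no", features, weights[types])


rng = np.random.default_rng(0)
num_types, d_in, d_out = 3, 4, 2
weights = rng.normal(size=(num_types, d_in, d_out))
features = rng.normal(size=(5, d_in))
types = np.array([0, 2, 1, 2, 0])
out = indexed_linear_sketch(features, types, weights)  # shape (5, 2)
```

Gathering a full matrix per node is simple but memory-hungry; grouping nodes by type and applying each matrix once is a common alternative.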

field: str
index_field: str
irreps_out: str | Irreps
name: str | None = None
num_types: int
out_field: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class tensorial.gcnn.IndexedRescale(num_types, index_field, field, out_field=None, shifts=None, scales=None, rescale_init=<function variance_scaling.<locals>.init>, shift_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Applies a per-type affine transformation (scale and shift) to a specified field in a graph.

Each input is scaled and shifted based on an associated index (e.g. atomic or node type). The transformation is of the form: output = input * scale + shift, where both scale and shift are either learnable parameters or provided constants, indexed by the value in index_field.

This is typically used to normalize or denormalize features like node energies, depending on the type of node or atom.

Variables:
  • num_types (int) – Number of unique types (i.e. distinct values in index_field). Determines the number of learnable scale and shift parameters.

  • index_field (str) – Path (e.g. “nodes.type”) to the array of indices used to select the scale and shift for each input.

  • field (str) – Path to the input field to be rescaled.

  • out_field (Optional[str]) – Path to the output field. If None, the result is written to field.

  • shifts (Optional[ArrayLike]) – Optional constant shift values of shape (num_types,). If None, the shifts are learned parameters initialized with shift_init.

  • scales (Optional[ArrayLike]) – Optional constant scale values of shape (num_types, 1). If None, the scales are learned parameters initialized with rescale_init.

  • rescale_init (Initializer) – Initializer for learnable scale parameters.

  • shift_init (Initializer) – Initializer for learnable shift parameters.

Returns:

A new graph with the specified field transformed and stored at out_field.

Return type:

jraph.GraphsTuple

Raises:

ValueError – If the number of types does not match the shape of provided scales or shifts.

Notes

  • Supports e3nn_jax.IrrepsArray input and preserves irreps metadata.

  • Uses jax.vmap internally for efficiency across nodes.

Parameters:
  • num_types (int)

  • index_field (str)

  • field (str)

  • out_field (str | None)

  • shifts (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, None])

  • scales (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, None])

  • rescale_init (Union[Initializer, Callable[..., Any]])

  • shift_init (Union[Initializer, Callable[..., Any]])

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])
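The affine map output = input * scale + shift is a gather of one scale and one shift per element followed by a multiply-add. A NumPy sketch, assuming per-type scalar scales and shifts (with the shapes documented above) that broadcast over any feature axes; `indexed_rescale_sketch` is hypothetical and omits the learnable-parameter path and IrrepsArray handling.

```python
import numpy as np


def indexed_rescale_sketch(values, types, scales, shifts):
    """Hypothetical sketch of output = input * scale[type] + shift[type]."""
    types = np.asarray(types)
    scale = np.asarray(scales).reshape(-1)[types]  # accepts (num_types,) or (num_types, 1)
    shift = np.asarray(shifts).reshape(-1)[types]
    if values.ndim > 1:
        # Broadcast the per-element scalar over any trailing feature axes.
        extra = (1,) * (values.ndim - 1)
        scale = scale.reshape(-1, *extra)
        shift = shift.reshape(-1, *extra)
    return values * scale + shift


# Normalised per-node values for nodes of type 0, 1, 0 (illustrative numbers).
per_node = np.array([0.1, -0.2, 0.3])
node_types = np.array([0, 1, 0])
scales = np.array([[2.0], [5.0]])  # shape (num_types, 1)
shifts = np.array([-1.0, -3.0])  # shape (num_types,)
rescaled = indexed_rescale_sketch(per_node, node_types, scales, shifts)
```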

field: str
index_field: str
name: str | None = None
num_types: int
out_field: str | None = None
parent: Module | Scope | _Sentinel | None = None
rescale_init(key: Array, shape: Sequence[int | Any], dtype: Any | None = None, out_sharding: NamedSharding | PartitionSpec | None = None) → Array
Parameters:
  • key (Array)

  • shape (Sequence[Union[int, Any]])

  • dtype (Any | None)

  • out_sharding (NamedSharding | PartitionSpec | None)

Return type:

Array

scales: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

shift_init(key, shape, dtype=None, out_sharding=None)

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
Parameters:
  • key (Array)

  • shape (Sequence[Union[int, Any]])

  • dtype (Any | None)

  • out_sharding (NamedSharding | PartitionSpec | None)

Return type:

Array

shifts: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None
class tensorial.gcnn.Jacfwd(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Parameters:
  • func (Callable[[GraphsTuple], GraphsTuple])

  • of (Union[str, tuple[str, ...]])

  • wrt (str | Sequence[Union[str, tuple[str, ...]]])

  • out_field (str | Sequence[str])

  • sign (float)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

func: Callable[[GraphsTuple], GraphsTuple]
name: str | None = None
of: str | tuple[str, ...]
out_field: str | Sequence[str] = 'auto'
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> import flax.linen as nn
    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(features=8, kernel_size=(3,))
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # triggers its setup() at most once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sign: float = 1.0
wrt: str | Sequence[str | tuple[str, ...]]
class tensorial.gcnn.Jacobian(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Parameters:
  • func (Callable[[GraphsTuple], GraphsTuple])

  • of (Union[str, tuple[str, ...]])

  • wrt (str | Sequence[Union[str, tuple[str, ...]]])

  • out_field (str | Sequence[str])

  • sign (float)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

func: Callable[[GraphsTuple], GraphsTuple]#
name: str | None = None#
of: str | tuple[str, ...]#
out_field: str | Sequence[str] = 'auto'#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

sign: float = 1.0#
wrt: str | Sequence[str | tuple[str, ...]]#
class tensorial.gcnn.Loss(loss_fn, targets, predictions=None, *, reduction='mean', label=None, mask_field=None)[source]#

Bases: GraphLoss

Simple loss function that extracts values from the graph and passes them to a function operating on numerical arrays, such as an optax loss.

Initializes the loss function wrapper with the specified parameters.

This constructor sets up a loss function wrapper that computes loss values from target and prediction fields. It supports various loss functions and provides options for reduction and masking.

Parameters:
  • loss_fn (str | Callable[[Array, Array], Array]) – The loss function to use, either as a string identifier or a callable implementing the PureLossFn protocol.

  • targets (str) – The path to the target field in the data structure.

  • predictions (str | None) – The path to the prediction field in the data structure. If None, the target field path is used.

  • reduction (Literal['sum', 'mean']) – The reduction method to apply to the computed losses. Either “sum” or “mean”.

  • label (str) – The label to use for the loss function. If None, the prediction field path is used as the label.

  • mask_field (str | None) – The path to the mask field in the data structure. If None, no masking is applied.

Raises:
  • ValueError – If the reduction method is not “sum” or “mean”.

  • TypeError – If the loss function is not a valid string or callable.
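The interplay of masking and reduction can be illustrated with a small NumPy sketch. It is not the library's implementation: `masked_reduce_sketch` and `squared_error` are hypothetical helpers, the squared error stands in for an optax-style loss, and the sketch assumes that a “mean” reduction averages over unmasked entries only.

```python
import numpy as np


def squared_error(targets, predictions):
    # Element-wise loss, standing in for an optax-style loss function
    return (predictions - targets) ** 2


def masked_reduce_sketch(loss_fn, targets, predictions, reduction="mean", mask=None):
    """Hypothetical helper: apply loss_fn, mask, then reduce (not the tensorial API)."""
    if reduction not in ("sum", "mean"):
        raise ValueError(f"Unknown reduction: {reduction!r}")
    losses = loss_fn(targets, predictions)
    if mask is None:
        mask = np.ones_like(losses, dtype=bool)  # no masking: keep every entry
    masked = np.where(mask, losses, 0.0)
    total = masked.sum()
    if reduction == "sum":
        return total
    # Assumption: "mean" averages over the unmasked entries only
    return total / max(int(mask.sum()), 1)


targets = np.array([1.0, 2.0, 3.0])
predictions = np.array([1.5, 2.0, 5.0])
mask = np.array([True, True, False])
mean_loss = masked_reduce_sketch(squared_error, targets, predictions, "mean", mask)
```

With the mask dropping the third entry, the remaining squared errors 0.25 and 0.0 give a mean of 0.125 and a sum of 0.25.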

class tensorial.gcnn.NequipLayer(irreps_out, invariant_layers=1, invariant_neurons=8, radial_num_layers=1, radial_num_neurons=8, radial_activation='swish', avg_num_neighbours=1.0, activations=FrozenDict({'e': 'silu', 'o': 'tanh'}), skip_connection=True, num_species=1, interaction_block=None, resnet=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

NequIP convolution layer.

Implementation based on: mir-group/nequip

Parameters:
  • irreps_out (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])

  • invariant_layers (int)

  • invariant_neurons (int)

  • radial_num_layers (int)

  • radial_num_neurons (int)

  • radial_activation (str | Callable[[Array], Array])

  • avg_num_neighbours (float | dict[int, float])

  • activations (str | Mapping[str, str | Callable[[Array], Array]])

  • skip_connection (bool)

  • num_species (int)

  • interaction_block (Callable)

  • resnet (bool)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

activations: str | Mapping[str, str | Callable[[Array], Array]] = FrozenDict({'e': 'silu', 'o': 'tanh'})#
avg_num_neighbours: float | dict[int, float] = 1.0#
interaction_block: Callable = None#
invariant_layers: int = 1#
invariant_neurons: int = 8#
irreps_out: None | Irrep | MulIrrep | str | Irreps | Sequence[str | Irrep | MulIrrep | tuple[int, None | Irrep | MulIrrep | str | Irreps | Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]#
name: str | None = None#
node_features_field = 'features'#
num_species: int = 1#
parent: Module | Scope | _Sentinel | None = None#
radial_activation: str | Callable[[Array], Array] = 'swish'#
radial_num_layers: int = 1#
radial_num_neurons: int = 8#
resnet: bool = False#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

skip_connection: bool = True#
class tensorial.gcnn.NodewiseDecoding(attrs, in_field='attributes', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Decode the direct sum of irreps stored in in_field and store each component tensor as a node value, keyed by the corresponding entry in attrs.

Parameters:
  • attrs (IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps])

  • in_field (str)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

attrs: IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps]#
in_field: str = 'attributes'#
name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
class tensorial.gcnn.NodewiseEmbedding(attrs, out_field='attributes', node_shape_from='positions', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Take the attributes in the nodes dictionary given by attrs, embed them, and store the results as a direct sum of irreps in the out_field.

Parameters:
  • attrs (IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps])

  • out_field (str)

  • node_shape_from (str | None)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

attrs: IrrepsObj | dict[str, Attr | IrrepsObj | type | dict | FrozenDict | Irreps]#
name: str | None = None#
node_shape_from: str | None = 'positions'#
out_field: str = 'attributes'#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
tensorial.gcnn.NodewiseEncoding#

alias of NodewiseEmbedding

class tensorial.gcnn.NodewiseLinear(irreps_out, irreps_in=None, field='features', out_field='features', num_types=None, types_field=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Nodewise linear operation.

Parameters:
  • irreps_out (str | Irreps)

  • irreps_in (Irreps | None)

  • field (str)

  • out_field (str | None)

  • num_types (int | None)

  • types_field (str | None)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

field: str = 'features'#
irreps_in: Irreps | None = None#
irreps_out: str | Irreps#
name: str | None = None#
num_types: int | None = None#
out_field: str | None = 'features'#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

types_field: str | None = None#
class tensorial.gcnn.NodewiseReduce(field, out_field=None, reduce='sum', average_num_atoms=None, as_array=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Applies a reduction operation over node features and stores the result in the graph globals.

This module reduces a specified field in the graph’s node features across all nodes (within each graph if batched) using a specified reduction operation (sum, mean, or normalized_sum). The result is written to the globals field of the GraphsTuple.

Variables:
  • field (str) – Path to the node field to reduce, e.g. “energy” or “features.energy”.

  • out_field (Optional[str]) – Path to the output global field. If None, defaults to “<reduce>_<field>” under globals.

  • reduce (str) – Reduction operation to apply. Must be one of “sum”, “mean”, or “normalized_sum”.

  • average_num_atoms (float) – Required if reduce is “normalized_sum”. Used to scale the result by average_num_atoms ** -0.5.

Raises:
  • ValueError – If reduce is not one of the allowed options.

  • ValueError – If reduce == “normalized_sum” but average_num_atoms is not provided.

Returns:

A new GraphsTuple with the reduced value written to the specified globals field.
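For a batched graph, the per-graph reduction can be sketched with `jax.ops.segment_sum`. This assumes a jraph-style layout in which the node values of all graphs are concatenated and `n_node` holds the node count of each graph; `nodewise_reduce_sketch` is a hypothetical helper, not this module's code, and it returns a plain array instead of writing into globals.

```python
import jax
import jax.numpy as jnp


def nodewise_reduce_sketch(values, n_node, reduce="sum", average_num_atoms=None):
    """Hypothetical per-graph reduction of node values (not the tensorial API).

    values: [total_nodes, ...] node values of all graphs in the batch, concatenated
    n_node: [n_graphs] number of nodes in each graph
    """
    if reduce not in ("sum", "mean", "normalized_sum"):
        raise ValueError(f"Unknown reduce: {reduce!r}")
    if reduce == "normalized_sum" and average_num_atoms is None:
        raise ValueError("average_num_atoms is required for 'normalized_sum'")
    n_graphs = n_node.shape[0]
    # Graph index of every node, e.g. n_node=[2, 3] -> [0, 0, 1, 1, 1]
    graph_ids = jnp.repeat(jnp.arange(n_graphs), n_node, total_repeat_length=values.shape[0])
    summed = jax.ops.segment_sum(values, graph_ids, num_segments=n_graphs)
    if reduce == "mean":
        # Broadcast node counts over any trailing feature dimensions
        counts = jnp.maximum(n_node, 1).reshape((-1,) + (1,) * (values.ndim - 1))
        return summed / counts
    if reduce == "normalized_sum":
        return summed * average_num_atoms ** -0.5
    return summed


values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
n_node = jnp.array([2, 3])
totals = nodewise_reduce_sketch(values, n_node)  # [3.0, 12.0]
```

The normalized_sum option differs from sum only by the constant factor average_num_atoms ** -0.5, so with average_num_atoms=4.0 the totals above are halved.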

Parameters:
  • field (str)

  • out_field (str | None)

  • reduce (str)

  • average_num_atoms (float)

  • as_array (bool)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

as_array: bool = False#
average_num_atoms: float = None#
field: str#
name: str | None = None#
out_field: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
reduce: str = 'sum'#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

class tensorial.gcnn.RadialBasisEdgeEmbedding(field='edge_lengths', out_field='radial_embeddings', num_basis=8, r_max=4.0, envelope=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Parameters:
  • field (str)

  • out_field (str)

  • num_basis (int)

  • r_max (float)

  • envelope (bool | Callable)

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])
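This page does not state which radial basis the module uses. As one possible illustration, the Bessel basis used in NequIP, $B_n(r) = \frac{2}{r_{\max}} \frac{\sin(n \pi r / r_{\max})}{r}$ for $n = 1, \dots, N_\text{basis}$, can be sketched in JAX as follows. Both the choice of basis and the helper `bessel_basis_sketch` are assumptions, and the optional envelope is not modelled.

```python
import jax.numpy as jnp


def bessel_basis_sketch(r, num_basis=8, r_max=4.0):
    """Hypothetical NequIP-style Bessel radial basis (not confirmed for this module).

    r: [n_edges] edge lengths, assumed strictly positive
    returns: [n_edges, num_basis]
    """
    n = jnp.arange(1, num_basis + 1)  # basis indices 1..num_basis
    r = r[:, None]  # broadcast each edge length against every basis index
    return (2.0 / r_max) * jnp.sin(n * jnp.pi * r / r_max) / r


edge_lengths = jnp.array([0.5, 1.0, 4.0])
embedding = bessel_basis_sketch(edge_lengths)  # shape (3, 8)
```

Every basis function of this form vanishes at r = r_max, since sin(nπ) = 0.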

envelope: bool | Callable = False#
field: str = 'edge_lengths'#
name: str | None = None#
num_basis: int = 8#
out_field: str = 'radial_embeddings'#
parent: Module | Scope | _Sentinel | None = None#
r_max: float = 4.0#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

tensorial.gcnn.RadialBasisEdgeEncoding#

alias of RadialBasisEdgeEmbedding

class tensorial.gcnn.Rescale(shift_fields=(), scale_fields=(), shift=0.0, scale=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Applies constant rescaling and/or shifting to fields in a jraph.GraphsTuple.

This module modifies specified fields in the graph — which may be located in the nodes, edges, or globals — by multiplying them with a scalar factor (scale) and/or adding a constant offset (shift). This is useful for normalizing or denormalizing values, or applying consistent physical unit conversions.

Both scale_fields and shift_fields may be either a single string (e.g. “nodes.energy”) or a sequence of path strings. Missing fields are ignored silently.

Example usage:#

Shift the energy stored in each node by 12.5:

>>> Rescale(shift_fields='nodes.energy', shift=12.5)

Rescale the global volume by a factor of 1e-3:

>>> Rescale(scale_fields=['globals.volume'], scale=1e-3)

Attributes:#

shift_fields : str | Sequence[Hashable]

Path(s) to the fields to which a constant shift should be applied.

scale_fields : str | Sequence[Hashable]

Path(s) to the fields to which a constant scale should be applied.

shift : jax.Array

Scalar constant to be added to all values in shift_fields. Defaults to 0.0.

scale : jax.Array

Scalar constant by which all values in scale_fields are multiplied. Defaults to 1.0.

Notes:#

  • Fields that are not found in the graph are skipped silently.

  • If a global field is shifted, a warning is logged that the field will no longer be size extensive with respect to the number of nodes or edges.
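The behaviour described above can be sketched on nested Python dicts standing in for a GraphsTuple. `rescale_sketch` is hypothetical; it applies scaling before shifting, which is an assumption because the order used by Rescale is not stated here, and it uses the standard logging module for the size-extensivity warning.

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def rescale_sketch(graph, shift_fields=(), scale_fields=(), shift=0.0, scale=1.0):
    """Hypothetical stand-in for Rescale acting on nested dicts (not the tensorial API)."""
    if isinstance(shift_fields, str):
        shift_fields = [shift_fields]
    if isinstance(scale_fields, str):
        scale_fields = [scale_fields]

    def update(tree, path, fn):
        # Return a copy of `tree` with fn applied at the dot-separated `path`
        head, *rest = path.split(".")
        if not isinstance(tree, dict) or head not in tree:
            return tree  # missing fields are skipped silently
        new_tree = dict(tree)
        new_tree[head] = fn(tree[head]) if not rest else update(tree[head], ".".join(rest), fn)
        return new_tree

    out = graph
    for path in scale_fields:  # assumption: scale first, then shift
        out = update(out, path, lambda x: x * scale)
    for path in shift_fields:
        if path.startswith("globals."):
            logger.warning("Shifting %s: it is no longer size extensive", path)
        out = update(out, path, lambda x: x + shift)
    return out


graph = {"nodes": {"energy": np.array([1.0, 2.0])}, "globals": {"volume": np.array([1000.0])}}
shifted = rescale_sketch(graph, shift_fields="nodes.energy", shift=12.5)  # energies 13.5, 14.5
```

Because every update returns a copy, the input graph is left unchanged, which matches the functional style expected inside JAX transformations.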

Parameters:
  • shift_fields (str | Sequence[Hashable])

  • scale_fields (str | Sequence[Hashable])

  • shift (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

  • scale (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])

  • parent (Union[Module, Scope, _Sentinel, None])

  • name (Optional[str])

name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0#
scale_fields: str | Sequence[Hashable] = ()#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

shift: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0#
shift_fields: str | Sequence[Hashable] = ()#
class tensorial.gcnn.WeightedLoss(loss_fns, weights=None)[source]#

Bases: GraphLoss

A weighted combination of multiple graph loss functions.

This class combines multiple graph loss functions with specified weights to create a composite loss function. The weights determine the contribution of each individual loss function to the final combined loss.

Parameters:
  • loss_fns (Sequence[GraphLoss]) – Sequence of graph loss functions to combine.

  • weights (Sequence[float] | None) – Sequence of weights, one for each loss function. If None, all weights are set to 1.0.

Raises:

ValueError – If any element in loss_fns is not an instance of GraphLoss, or if the number of weights does not match the number of loss functions.
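The combination rule amounts to $L = \sum_i w_i L_i$. A plain-Python sketch follows; `weighted_loss_sketch` and the two toy losses are hypothetical, and returning the per-term contributions only loosely mirrors loss_with_contributions, whose dictionary keys are not documented on this page.

```python
from typing import Callable, Optional, Sequence

LossFn = Callable[[dict, dict], float]  # hypothetical: (predictions, targets) -> loss


def weighted_loss_sketch(
    loss_fns: Sequence[LossFn],
    predictions: dict,
    targets: dict,
    weights: Optional[Sequence[float]] = None,
) -> tuple[float, list[float]]:
    """Hypothetical sketch of L = sum_i w_i * L_i (not the tensorial API)."""
    if weights is None:
        weights = [1.0] * len(loss_fns)  # default: every loss weighted by 1.0
    if len(weights) != len(loss_fns):
        raise ValueError("the number of weights must match the number of loss functions")
    contributions = [w * fn(predictions, targets) for w, fn in zip(weights, loss_fns)]
    return sum(contributions), contributions


def energy_loss(p: dict, t: dict) -> float:
    return (p["energy"] - t["energy"]) ** 2


def force_loss(p: dict, t: dict) -> float:
    return abs(p["force"] - t["force"])


preds = {"energy": 2.0, "force": 1.0}
targs = {"energy": 1.0, "force": 3.0}
total, parts = weighted_loss_sketch([energy_loss, force_loss], preds, targs, weights=[1.0, 0.5])
```

Here the individual losses are 1.0 and 2.0, so the weighted contributions are [1.0, 1.0] and the total is 2.0.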

loss_with_contributions(predictions, target)[source]#
Parameters:
  • predictions (GraphsTuple)

  • target (GraphsTuple)

Return type:

tuple[Array, dict[str, float]]

property weights#
tensorial.gcnn.adapt(fn, *args, outs=(), return_graphs=False, **keywords)[source]#

Given a graph function, return a function that takes a graph as its first argument, followed by positional arguments that are mapped to the fields given by the input paths in *args. Output paths can optionally be specified with outs; if supplied, the returned function returns one or more values from the graph produced by fn.

Parameters:
  • fn (ExGraphFunction) – the graph function

  • *args (Union[str, tuple[str, ...]]) – the input paths

  • outs (Sequence[Union[str, tuple[str, ...]]]) – the output paths

  • return_graphs (bool) – if True and outs is specified, return a tuple containing the output graph followed by the values at outs

  • keywords (Union[str, tuple[str, ...]])

Return type:

TransformedGraphFunction

Returns:

a function that wraps fn with the above properties
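The wrapping behaviour of adapt can be sketched on nested dicts standing in for a graph, with dot-separated path strings. Everything here (`adapt_sketch`, `_get`, `_set`, the toy `energy_fn`) is hypothetical, and unpacking a single output value rather than returning a 1-tuple is an assumption.

```python
from typing import Any, Callable

import numpy as np

Graph = dict  # hypothetical stand-in for a GraphsTuple: nested dicts


def _get(graph: Graph, path: str) -> Any:
    # Follow a dot-separated path such as "globals.energy"
    for key in path.split("."):
        graph = graph[key]
    return graph


def _set(graph: Graph, path: str, value: Any) -> Graph:
    # Return a copy of `graph` with `value` stored at `path`
    head, *rest = path.split(".")
    new = dict(graph)
    new[head] = value if not rest else _set(graph.get(head, {}), ".".join(rest), value)
    return new


def adapt_sketch(fn: Callable[[Graph], Graph], *in_paths: str, outs=(), return_graphs=False):
    """Hypothetical re-creation of the adapt() idea (not the tensorial API)."""
    def wrapped(graph: Graph, *values: Any):
        # Map each positional value onto the matching input path
        for path, value in zip(in_paths, values):
            graph = _set(graph, path, value)
        out_graph = fn(graph)
        if not outs:
            return out_graph
        results = tuple(_get(out_graph, p) for p in outs)
        if return_graphs:
            return (out_graph, *results)
        return results[0] if len(results) == 1 else results
    return wrapped


def energy_fn(graph: Graph) -> Graph:
    # Toy graph function: E = 1/2 * sum(x^2) over node positions
    return _set(graph, "globals.energy", 0.5 * np.sum(graph["nodes"]["positions"] ** 2))


energy_of = adapt_sketch(energy_fn, "nodes.positions", outs=["globals.energy"])
graph = {"nodes": {"positions": np.array([1.0, 2.0])}, "globals": {}}
energy = energy_of(graph, np.array([3.0, 4.0]))  # 0.5 * (9 + 16) = 12.5
```

The point of this design is that the wrapped function exposes graph fields as ordinary positional arguments, which is the form that array-level tools expect.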

tensorial.gcnn.diff(*func_of, wrt, out=None, scale=1.0, at=None, return_graph=False)[source]#

Constructs a JAX-compatible evaluator for computing single or multiple derivatives of a scalar function (e.g., energy) defined over a Graph.

This function acts as a factory, routing the request to create either a SingleDerivative or MultiDerivative object based on the wrt argument.

The key feature of this function is the use of string-based tensor index notation (e.g., ‘nodes.positions:Iγ’) for specifying differentiation targets and output shape.

Parameters:
  • *func_of –

    Either (func), where ‘func’ is the energy function, or (func, of), where ‘of’ is the scalar entry (e.g., ‘globals.energy’) to differentiate.

      - func (Callable): The function (Graph -> Graph) whose output is differentiated.

      - of (GraphEntrySpecLike, optional): Specifies the scalar entry within the output graph of ‘func’ to differentiate. Defaults to the sole scalar output if omitted.

  • wrt (GraphEntrySpecLike | Sequence[GraphEntrySpecLike]) –

    The input entries of the graph with respect to which the derivative is taken. This must be a string or a sequence of strings specifying the index notation for each input. For example, ‘nodes.positions:Iγ’ means differentiate with respect to the $\gamma$ component of the position of node $I$.

  • out (GraphEntrySpecLike, optional) –

    The index notation for the desired output tensor shape. This string defines the contraction of the indices from ‘wrt’. For example, if wrt=[‘field:α’, ‘field:β’, ‘positions:Iγ’], then out=’:Iγαβ’ specifies a rank-4 tensor output with indices $I, \gamma, \alpha, \beta$. If omitted, the indices are concatenated in the order they appear in ‘wrt’.

  • scale (float) – A scalar factor to multiply the final derivative result by. Defaults to 1.0.

  • at (dict | None) – A dictionary mapping GraphEntrySpecLike strings (without indices) to jax.numpy arrays, specifying the value at which to evaluate the derivative for those entries. These entries will be held constant. For example: {‘globals.electric_field’: jnp.zeros(3)}

  • return_graph (bool) – If True, the derivative tensor is packaged into a new Graph object under the name specified by the ‘out’ argument. If False, the function returns the raw derivative tensor. Defaults to False.

Returns:

A callable object that takes a Graph and returns the computed derivative tensor (or a Graph containing it).

Return type:

Evaluator

Raises:

TypeError – If the arguments do not conform to the expected types.

Note

The index notation used in ‘wrt’ and ‘out’ must adhere to the library’s conventions for Graph entry keys and indices. For multi-derivatives, the number of unique indices in ‘out’ must match the number of indices in ‘wrt’.
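To make the index notation concrete, the tensor described by wrt=[‘field:α’, ‘field:β’, ‘positions:Iγ’] with out=’:Iγαβ’ is $T_{I\gamma\alpha\beta} = \partial^3 E / (\partial r_{I\gamma}\,\partial \mathcal{E}_\alpha\,\partial \mathcal{E}_\beta)$, multiplied by scale, where $\mathcal{E}$ denotes the ‘field’ entry. The sketch below computes such a tensor with plain JAX transforms for a toy energy; it does not call diff, and the energy function and the interpretation of ‘field’ as a 3-vector are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def toy_energy(pos, field):
    # Hypothetical energy: E = 1/2 * sum_I (field . r_I)^2
    return 0.5 * jnp.sum((pos @ field) ** 2)


# dE/dr_{Iγ}: a function returning shape (N, 3), indices I, γ
dE_dpos = jax.grad(toy_energy, argnums=0)

# jacfwd appends the differentiated axis last, so two derivatives with respect
# to the field yield indices I, γ, α, β, i.e. the ':Iγαβ' ordering
d3E = jax.jacfwd(jax.jacfwd(dE_dpos, argnums=1), argnums=1)

pos = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
field = jnp.array([0.1, 0.2, 0.3])
t = d3E(pos, field)  # shape (2, 3, 3, 3)
```

For this energy the result is $T_{I\gamma\alpha\beta} = r_{I\alpha}\delta_{\gamma\beta} + r_{I\beta}\delta_{\gamma\alpha}$, independent of the field value.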

tensorial.gcnn.grad(of, wrt, sign=1.0, has_aux=False)[source]#

Build a partially initialised Grad function whose only remaining argument is the graph function to differentiate.

Parameters:
  • kwargs – accepts any arguments that Grad does

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function
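The partial-application pattern (fix of, wrt and sign first, then supply the graph function) can be sketched with plain jax.grad on flat dicts. `grad_sketch` and `harmonic_energy` are hypothetical; sign=-1.0 reflects the common use of computing forces $F = -\partial E / \partial r$.

```python
import jax
import jax.numpy as jnp


def grad_sketch(of, wrt, sign=1.0):
    """Hypothetical analogue of the grad(of, wrt, sign) pattern on flat dicts."""
    def bind(func):
        def evaluate(graph):
            def scalar_of(x):
                # Substitute the `wrt` entry, run func, and read back the scalar `of` entry
                return func({**graph, wrt: x})[of]
            return sign * jax.grad(scalar_of)(graph[wrt])
        return evaluate
    return bind


def harmonic_energy(graph):
    # Toy graph function adding E = 1/2 * sum |r|^2 under the "energy" key
    return {**graph, "energy": 0.5 * jnp.sum(graph["positions"] ** 2)}


forces_fn = grad_sketch(of="energy", wrt="positions", sign=-1.0)(harmonic_energy)
positions = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, -1.0]])
forces = forces_fn({"positions": positions})  # equals -positions for this energy
```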

tensorial.gcnn.graph_from_points(pos, r_max, *, fractional_positions=False, self_interaction=True, strict_self_interaction=False, cell=None, pbc=None, nodes=None, edges=None, graph_globals=None, np_=<module 'numpy'>)[source]#

Create a jraph Graph from a set of atomic positions and other related data.

Parameters:
  • pos (Union[Float[Array, 'n_nodes 3'], Float[ndarray, 'n_nodes 3']]) – a [N, 3] array of atomic positions

  • r_max (Number) – the cutoff radius to use for identifying neighbours

  • fractional_positions (bool) – if True, pos are interpreted as fractional positions

  • self_interaction (bool) – if True, edges are created between an atom and its periodic images in other unit cells

  • strict_self_interaction (bool) – if True, edges are created between an atom and itself within the central unit cell

  • cell (Optional[Union[Float[Array, '3 3'], Float[ndarray, '3 3']]]) – a [3, 3] array of unit cell vectors (in row-major format)

  • pbc (Union[bool, Union[tuple[bool, bool, bool], Bool[Array, '3'], Bool[ndarray, '3'], Bool[TypedNdArray, '3']], None]) – a bool or a sequence of three bools indicating whether the space is periodic in the x, y and z directions

  • nodes (dict[str, Union[Num[Array, 'n_nodes *'], Num[ndarray, 'n_nodes *']]] | None) – a dictionary containing additional data relating to each node; it should contain arrays of shape [N, …]

  • edges (dict | None)

  • graph_globals (dict[str, Union[Array, ndarray]] | None) – a dictionary containing additional global data

Return type:

GraphsTuple

Returns:

the corresponding jraph Graph
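graph_from_points also handles periodic cells, fractional coordinates and the self-interaction options; the sketch below shows only the core idea of a cutoff-based neighbour search in the simplest, non-periodic case, using an O(N²) brute-force search. `neighbour_edges_sketch` is hypothetical, and whether the cutoff is strict (<) or inclusive is an assumption.

```python
import numpy as np


def neighbour_edges_sketch(pos, r_max):
    """Hypothetical brute-force, non-periodic neighbour search (not tensorial's algorithm).

    Returns (senders, receivers) for every ordered pair i != j with |r_j - r_i| < r_max.
    """
    pos = np.asarray(pos, dtype=float)
    # Pairwise displacement vectors d[i, j] = r_j - r_i, shape [N, N, 3]
    disp = pos[None, :, :] - pos[:, None, :]
    dist = np.linalg.norm(disp, axis=-1)
    # Exclude i == j: no self-edges within a single, non-periodic cell
    within = (dist < r_max) & ~np.eye(len(pos), dtype=bool)
    senders, receivers = np.nonzero(within)
    return senders, receivers


senders, receivers = neighbour_edges_sketch([[0, 0, 0], [1, 0, 0], [2, 0, 0]], r_max=1.5)
```

For three atoms spaced 1.0 apart on a line with r_max=1.5, only nearest neighbours are connected, giving the directed edges (0, 1), (1, 0), (1, 2) and (2, 1).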

tensorial.gcnn.graph_metric(metric, predictions, targets=None, mask='auto', normalise_by=None)[source]#
Parameters:
  • metric (str | Metric | type[Metric])

  • predictions (Union[str, tuple[str, ...]])

  • targets (Union[str, tuple[str, ...], None])

  • mask (Union[str, tuple[str, ...], Literal['auto'], None])

  • normalise_by (Union[str, tuple[str, ...], None])

Return type:

GraphMetric

tensorial.gcnn.hessian(of, wrt, sign=1.0, has_aux=False)[source]#

Build a partially initialised Grad function whose only remaining argument is the graph function to differentiate.

Parameters:
  • kwargs – accepts any arguments that Grad does

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function

tensorial.gcnn.jacfwd(of, wrt, sign=1.0, has_aux=False)[source]#

Build a partially initialised Grad function whose only remaining argument is the graph function to differentiate.

Parameters:
  • kwargs – accepts any arguments that Grad does

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function

tensorial.gcnn.jacobian(of, wrt, sign=1.0, has_aux=False)#

Build a partially initialised Grad function whose only remaining argument is the graph function to differentiate.

Parameters:
  • kwargs – accepts any arguments that Grad does

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function

tensorial.gcnn.jacrev(of, wrt, sign=1.0, has_aux=False)[source]#

Build a partially initialised Grad function whose only remaining argument is the graph function to differentiate.

Parameters:
  • kwargs – accepts any arguments that Grad does

  • of (Union[str, tuple[str, ...]])

  • wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])

  • sign (float)

  • has_aux (bool)

Return type:

Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]

Returns:

the partially initialized Grad function

tensorial.gcnn.reduce(*args, **kwargs)[source]#
tensorial.gcnn.shape_check(func)[source]#

Decorator that logs any differences between the graph before and after the call, either in the keys present or in their shapes.

This is useful for diagnosing JAX recompilation issues, since a change in array shapes causes jit-compiled functions to be retraced and recompiled.

Parameters:

func (Callable[[GraphsTuple], GraphsTuple])

Return type:

Callable[[GraphsTuple], GraphsTuple]
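A simplified version of this idea, for flat dicts of arrays rather than GraphsTuple objects, is sketched below. `shape_diff`, `shape_check_sketch` and the toy `truncate_and_annotate` function are hypothetical, and logging at INFO level is an arbitrary choice.

```python
import functools
import logging

import numpy as np

logger = logging.getLogger(__name__)


def shape_diff(before: dict, after: dict) -> dict:
    """Hypothetical helper: keys added, keys removed, and keys whose shape changed."""
    shared = before.keys() & after.keys()
    return {
        "added": sorted(after.keys() - before.keys()),
        "removed": sorted(before.keys() - after.keys()),
        "changed": {
            k: (np.shape(before[k]), np.shape(after[k]))
            for k in sorted(shared)
            if np.shape(before[k]) != np.shape(after[k])
        },
    }


def shape_check_sketch(func):
    """Hypothetical decorator in the spirit of shape_check, for flat dicts of arrays."""
    @functools.wraps(func)
    def wrapper(graph: dict) -> dict:
        out = func(graph)
        diff = shape_diff(graph, out)
        if any(diff.values()):
            logger.info("%s changed the graph: %s", func.__name__, diff)
        return out
    return wrapper


@shape_check_sketch
def truncate_and_annotate(graph: dict) -> dict:
    # Toy graph function that adds a key and changes a shape
    return {**graph, "lengths": np.ones(3), "positions": graph["positions"][:2]}
```

Calling truncate_and_annotate({"positions": np.zeros((3, 3))}) would log that "lengths" was added and that "positions" changed shape from (3, 3) to (2, 3).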

tensorial.gcnn.transform_fn(fn, *args, outs=(), return_graphs=False, **keywords)#

Given a graph function, return a function that takes a graph as its first argument, followed by positional arguments that are mapped to the fields given by the input paths in *args. Output paths can optionally be specified with outs; if supplied, the returned function returns one or more values from the graph produced by fn.

Parameters:
  • fn (ExGraphFunction) – the graph function

  • *args (Union[str, tuple[str, ...]]) – the input paths

  • outs (Sequence[Union[str, tuple[str, ...]]]) – the output paths

  • return_graphs (bool) – if True and outs is specified, return a tuple containing the output graph followed by the values at outs

  • keywords (Union[str, tuple[str, ...]])

Return type:

TransformedGraphFunction

Returns:

a function that wraps fn with the above properties

tensorial.gcnn.with_edge_vectors(graph, with_lengths=True, as_irreps_array=True)[source]#

Compute the edge displacement vectors of a graph (and, if with_lengths is True, their lengths).

The results are stored as edge attributes, which act as a cache: if they are already present, they are not recalculated.

Parameters:
  • graph (GraphsTuple)

  • with_lengths (bool)

  • as_irreps_array (bool | None)

Return type:

GraphsTuple
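The caching behaviour can be sketched on a dict-based graph. The sketch assumes edge vectors point from sender to receiver ($r_\text{receiver} - r_\text{sender}$), ignores periodic cell shifts, and uses hypothetical key names (vectors, lengths) and a hypothetical helper `with_edge_vectors_sketch`; none of these details is confirmed by this page.

```python
import numpy as np


def with_edge_vectors_sketch(graph: dict, with_lengths: bool = True) -> dict:
    """Hypothetical sketch of cached edge-vector computation (not the tensorial API)."""
    edges = dict(graph["edges"])
    if "vectors" not in edges:
        # Assumed convention: vectors point from sender to receiver; no PBC shifts
        pos = graph["nodes"]["positions"]
        edges["vectors"] = pos[graph["receivers"]] - pos[graph["senders"]]
    if with_lengths and "lengths" not in edges:
        edges["lengths"] = np.linalg.norm(edges["vectors"], axis=-1)
    return {**graph, "edges": edges}


graph = {
    "nodes": {"positions": np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])},
    "edges": {},
    "senders": np.array([0, 1]),
    "receivers": np.array([1, 0]),
}
out = with_edge_vectors_sketch(graph)  # lengths: [5.0, 5.0]
```

Calling the sketch a second time on its own output reuses the stored vectors instead of recomputing them.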