tensorial.gcnn package#
Subpackages#
- tensorial.gcnn.atomic package
- tensorial.gcnn.data package
- tensorial.gcnn.experimental package
Submodules#
tensorial.gcnn.calc module#
- tensorial.gcnn.calc.cell_volume(cell_vectors, np_=None)[source]#
Computes the volume of a unit cell defined by its cell vectors.
The cell volume is calculated as the absolute value of the determinant of the matrix formed by the cell vectors. This is commonly used in crystallography and computational physics to determine the volume of a unit cell.
- Parameters:
- Return type:
Union[Array, ndarray]
- Returns:
The volume of the unit cell as a scalar tensor.
- Raises:
LinAlgError – If the cell vectors are linearly dependent, resulting in a zero determinant.
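As a minimal sketch of the calculation described above (not the library's implementation), the volume can be computed with plain NumPy. The name cell_volume_sketch is hypothetical, and the rows of the matrix are assumed to be the cell vectors.

```python
import numpy as np


def cell_volume_sketch(cell_vectors):
    """Return |det(A)|, where the rows of A are the cell vectors."""
    cell = np.asarray(cell_vectors, dtype=float)
    return np.abs(np.linalg.det(cell))


# Orthorhombic cell with edge lengths 2, 3 and 4
cell = np.diag([2.0, 3.0, 4.0])
volume = cell_volume_sketch(cell)
```

In this sketch a degenerate cell (linearly dependent vectors) simply gives a volume of zero.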
tensorial.gcnn.derivatives module#
- class tensorial.gcnn.derivatives.Grad(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
The Grad class computes gradients of graph-based functions with respect to specified graph attributes. It enables automatic differentiation of operations defined on graph structures, such as computing how changes in node positions affect edge lengths or other graph properties. The class supports both scalar and vector-valued gradients and integrates with JAX for efficient computation.
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])
out_field (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])
sign (float)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- func: Callable[[GraphsTuple], GraphsTuple]#
- name: str | None = None#
- of: str | tuple[str, ...]#
- out_field: str | tuple[str, ...] | list[str | tuple[str, ...]] = 'auto'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- sign: float = 1.0#
- wrt: str | tuple[str, ...] | list[str | tuple[str, ...]]#
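The idea behind Grad can be illustrated with plain jax.grad. This is a sketch under stated assumptions rather than the module's actual code: total_energy, the array layout, and treating sign as a simple multiplier on the gradient are all hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def total_energy(positions, senders, receivers):
    """Hypothetical scalar output: sum of squared edge lengths."""
    edge_vectors = positions[receivers] - positions[senders]
    return jnp.sum(edge_vectors**2)


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
senders = jnp.array([0])
receivers = jnp.array([1])

# Assumption: sign scales the gradient, so sign=-1.0 turns dE/dr into forces
sign = -1.0
forces = sign * jax.grad(total_energy)(positions, senders, receivers)
```

Here the differentiated quantity plays the role of of and positions plays the role of wrt.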
- class tensorial.gcnn.derivatives.Jacfwd(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
of (Union[str, tuple[str, ...]])
wrt (str | Sequence[Union[str, tuple[str, ...]]])
out_field (str | Sequence[str])
sign (float)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- func: Callable[[GraphsTuple], GraphsTuple]#
- name: str | None = None#
- of: str | tuple[str, ...]#
- out_field: str | Sequence[str] = 'auto'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- sign: float = 1.0#
- wrt: str | Sequence[str | tuple[str, ...]]#
- class tensorial.gcnn.derivatives.Jacobian(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
of (Union[str, tuple[str, ...]])
wrt (str | Sequence[Union[str, tuple[str, ...]]])
out_field (str | Sequence[str])
sign (float)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- func: Callable[[GraphsTuple], GraphsTuple]#
- name: str | None = None#
- of: str | tuple[str, ...]#
- out_field: str | Sequence[str] = 'auto'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- sign: float = 1.0#
- wrt: str | Sequence[str | tuple[str, ...]]#
- tensorial.gcnn.derivatives.grad(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialised Grad function
- tensorial.gcnn.derivatives.hessian(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialised Grad function
- tensorial.gcnn.derivatives.jacfwd(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialised Grad function
- tensorial.gcnn.derivatives.jacobian(of, wrt, sign=1.0, has_aux=False)#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialised Grad function
- tensorial.gcnn.derivatives.jacrev(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialised Grad function
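The helper names in this module mirror JAX's own transformations. The following sketch uses jax.jacfwd, jax.jacrev and jax.hessian directly to show how the three relate. It assumes, without confirmation from this page, that the helpers wrap these transformations; edge_lengths is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


def edge_lengths(positions):
    """Hypothetical vector-valued output: length of each consecutive edge."""
    return jnp.linalg.norm(positions[1:] - positions[:-1], axis=-1)


positions = jnp.array([[0.0, 0.0], [3.0, 4.0], [3.0, 5.0]])

# Forward- and reverse-mode Jacobians agree; they differ only in cost profile
jac_fwd = jax.jacfwd(edge_lengths)(positions)
jac_rev = jax.jacrev(edge_lengths)(positions)

# Second derivatives of a scalar built from the same outputs
hess = jax.hessian(lambda p: jnp.sum(edge_lengths(p) ** 2))(positions)
```

Forward mode tends to suit functions with few inputs and many outputs, and reverse mode the opposite.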
tensorial.gcnn.keys module#
tensorial.gcnn.losses module#
- class tensorial.gcnn.losses.Loss(loss_fn, targets, predictions=None, *, reduction='mean', label=None, mask_field=None)[source]#
Bases: GraphLoss
Simple loss function that passes values from the graph to a function taking numerical values such as optax losses
Initializes the loss function wrapper with specified parameters.
This constructor sets up a loss function wrapper that can be used to compute loss values based on target and prediction fields. It supports various loss functions and provides options for reduction and masking.
- Parameters:
loss_fn (str | Callable[[Array, Array], Array]) – The loss function to use, either as a string identifier or a callable implementing the PureLossFn protocol.
targets (str) – The path to the target field in the data structure.
predictions (str | None) – The path to the prediction field in the data structure. If None, the target field path will be used.
reduction (Literal['sum', 'mean']) – The reduction method to apply to the computed losses. Either “sum” or “mean”.
label (str) – The label to use for the loss function. If None, the prediction field path will be used as the label.
mask_field (str | None) – The path to the mask field in the data structure. If None, no masking will be applied.
- Raises:
ValueError – If the reduction method is not “sum” or “mean”.
TypeError – If the loss function is not a valid string or callable.
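The interaction of reduction and masking can be sketched in plain NumPy. This is an illustration only: masked_loss and squared_error are hypothetical names, and averaging over the unmasked entries is an assumption about how the mean is normalised.

```python
import numpy as np


def masked_loss(loss_fn, predictions, targets, mask=None, reduction="mean"):
    """Apply an elementwise loss, zero out masked entries, then reduce."""
    if reduction not in ("sum", "mean"):
        raise ValueError(f"Unknown reduction: {reduction!r}")
    losses = loss_fn(predictions, targets)
    if mask is None:
        mask = np.ones(losses.shape, dtype=bool)
    total = np.where(mask, losses, 0.0).sum()
    if reduction == "sum":
        return total
    # Assumption: the mean is taken over unmasked entries only
    return total / max(int(mask.sum()), 1)


def squared_error(predictions, targets):
    return (predictions - targets) ** 2


predictions = np.array([1.0, 2.0, 10.0])
targets = np.array([0.0, 2.0, 0.0])
mask = np.array([True, True, False])  # e.g. the last entry is padding

mean_loss = masked_loss(squared_error, predictions, targets, mask)
sum_loss = masked_loss(squared_error, predictions, targets, mask, reduction="sum")
```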
- class tensorial.gcnn.losses.WeightedLoss(loss_fns, weights=None)[source]#
Bases: GraphLoss
A weighted combination of multiple graph loss functions.
This class combines multiple graph loss functions with specified weights to create a composite loss function. The weights determine the contribution of each individual loss function to the final combined loss.
- Parameters:
loss_fns (Sequence[GraphLoss])
weights (Sequence[float] | None)
- Raises:
ValueError – If any element in loss_fns is not a subclass of GraphLoss, or if the number of weights does not match the number of loss functions.
- loss_with_contributions(predictions, target)[source]#
- Parameters:
predictions (GraphsTuple)
target (GraphsTuple)
- Return type:
tuple[Array, dict[str, float]]
- property weights#
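A rough sketch of a weighted combination that also reports per-term contributions, in plain Python. The function name, the default of unit weights, and reporting the weighted (rather than raw) contributions are assumptions made for illustration.

```python
def weighted_loss_with_contributions(losses, weights=None):
    """Combine labelled loss values; return (total, per-label contributions)."""
    labels = list(losses)
    if weights is None:
        # Assumption: missing weights default to 1.0 for every term
        weights = [1.0] * len(labels)
    if len(weights) != len(labels):
        raise ValueError("Number of weights must match number of loss functions")
    contributions = {
        label: weight * losses[label] for label, weight in zip(labels, weights)
    }
    return sum(contributions.values()), contributions


total, contributions = weighted_loss_with_contributions(
    {"energy": 0.5, "forces": 2.0}, weights=[1.0, 10.0]
)
```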
tensorial.gcnn.metrics module#
- class tensorial.gcnn.metrics.GraphMetric(state=None)[source]#
Bases: Metric
- Parameters:
state (Optional[Metric[TypeVar(OutT)]])
- create(predictions, targets=None)[source]#
Create a new metric instance from data.
- Parameters:
predictions (GraphsTuple)
targets (GraphsTuple | None)
- Return type:
- property is_empty: bool#
- mask_key: ClassVar[str | tuple[str, ...] | None] = 'auto'#
- merge(other)[source]#
Merge the metric with data from another metric instance of the same type.
- Parameters:
other (GraphMetric)
- Return type:
- property metric: Metric[OutT] | None#
- normalise_by: ClassVar[str | tuple[str, ...] | None] = None#
- parent: ClassVar[Metric]#
- pred_key: ClassVar[str | tuple[str, ...]]#
- target_key: ClassVar[str | tuple[str, ...] | None] = None#
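The create/merge pattern can be sketched with a small accumulator. MeanAbsError and its compute method are hypothetical and work on plain Python lists rather than GraphsTuple fields. The point is only that per-batch states are merged before the final value is computed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MeanAbsError:
    """Hypothetical mergeable metric holding a running sum and count."""

    total: float = 0.0
    count: int = 0

    @classmethod
    def create(cls, predictions, targets):
        errors = [abs(p - t) for p, t in zip(predictions, targets)]
        return cls(total=sum(errors), count=len(errors))

    def merge(self, other):
        return MeanAbsError(self.total + other.total, self.count + other.count)

    @property
    def is_empty(self):
        return self.count == 0

    def compute(self):
        return self.total / self.count


batch_1 = MeanAbsError.create([1.0, 2.0], [0.0, 2.0])  # absolute errors 1 and 0
batch_2 = MeanAbsError.create([5.0], [3.0])  # absolute error 2
merged = batch_1.merge(batch_2)
```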
- tensorial.gcnn.metrics.graph_metric(metric, predictions, targets=None, mask='auto', normalise_by=None)[source]#
- Parameters:
metric (str | Metric | type[Metric])
predictions (Union[str, tuple[str, ...]])
targets (Union[str, tuple[str, ...], None])
mask (Union[str, tuple[str, ...], Literal['auto'], None])
normalise_by (Union[str, tuple[str, ...], None])
- Return type:
tensorial.gcnn.random module#
- tensorial.gcnn.random.spatial_graph(rng_key, num_nodes=None, num_graphs=None, cutoff=0.4, nodes=None)[source]#
Create graph(s) with nodes that have random positions
- Parameters:
rng_key (Array)
num_nodes (int)
nodes (dict[str, Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray], int], Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray]]]] | None)
- Return type:
GraphsTuple | list[GraphsTuple]
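One way to picture a random spatial graph is to sample positions and connect every pair of nodes closer than the cutoff. The NumPy sketch below does exactly that. Sampling uniformly in the unit cube and returning plain arrays instead of a GraphsTuple are assumptions, and random_spatial_edges is a hypothetical name.

```python
import numpy as np


def random_spatial_edges(rng, num_nodes, cutoff=0.4):
    """Sample positions in the unit cube and link pairs closer than cutoff."""
    positions = rng.uniform(size=(num_nodes, 3))
    deltas = positions[:, None, :] - positions[None, :, :]
    distances = np.linalg.norm(deltas, axis=-1)
    # Exclude self-edges on the diagonal
    connected = (distances < cutoff) & ~np.eye(num_nodes, dtype=bool)
    senders, receivers = np.nonzero(connected)
    return positions, senders, receivers


rng = np.random.default_rng(0)
positions, senders, receivers = random_spatial_edges(rng, num_nodes=8)
```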
tensorial.gcnn.typing module#
tensorial.gcnn.utils module#
- class tensorial.gcnn.utils.UpdateDict(updating)[source]#
Bases: MutableMapping
This class can be used to make updates to a dictionary without modifying the passed dictionary. Once all the updates are made, a new dictionary that is the result of the modifications can be retrieved using the _asdict() method.
- Parameters:
updating (dict)
- DELETED = ()#
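The behaviour described above can be sketched as an overlay that records changes separately from the wrapped dictionary. OverlayDict is a hypothetical stand-in, not the library class, and its sentinel-based deletion is an assumption about how such a wrapper might work.

```python
from collections.abc import MutableMapping


class OverlayDict(MutableMapping):
    """Record updates on top of a base dict without mutating it."""

    _DELETED = object()  # marks keys removed in the overlay

    def __init__(self, base):
        self._base = base
        self._changes = {}

    def __getitem__(self, key):
        if key in self._changes:
            value = self._changes[key]
            if value is self._DELETED:
                raise KeyError(key)
            return value
        return self._base[key]

    def __setitem__(self, key, value):
        self._changes[key] = value

    def __delitem__(self, key):
        if key not in self:
            raise KeyError(key)
        self._changes[key] = self._DELETED

    def __iter__(self):
        for key in self._base:
            if key not in self._changes:
                yield key
        for key, value in self._changes.items():
            if value is not self._DELETED:
                yield key

    def __len__(self):
        return sum(1 for _ in self)

    def _asdict(self):
        return {key: self[key] for key in self}


original = {"a": 1, "b": 2}
view = OverlayDict(original)
view["c"] = 3
del view["a"]
result = view._asdict()
```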
Module contents#
- class tensorial.gcnn.EdgeVectors(as_irreps_arrays=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Create edge vectors from atomic positions. This will take into account the unit cell (if present).
- Parameters:
as_irreps_arrays (bool)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- as_irreps_arrays: bool = False#
- name: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
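A minimal NumPy sketch of edge vectors under periodic boundary conditions follows. The receiver-minus-sender convention, the integer cell_shifts array, and the function name are assumptions for illustration. The module itself operates on graph fields.

```python
import numpy as np


def edge_vectors(positions, senders, receivers, cell=None, cell_shifts=None):
    """Vector from sender to receiver, plus a lattice translation if periodic."""
    vectors = positions[receivers] - positions[senders]
    if cell is not None and cell_shifts is not None:
        # Each integer shift selects a periodic image via the cell rows
        vectors = vectors + cell_shifts @ cell
    return vectors


positions = np.array([[0.1, 0.0, 0.0], [0.9, 0.0, 0.0]])
cell = np.eye(3)
senders = np.array([1])
receivers = np.array([0])
cell_shifts = np.array([[1, 0, 0]])  # receiver image lies one cell along x

vectors = edge_vectors(positions, senders, receivers, cell, cell_shifts)
```

With the shift applied, the two atoms are 0.2 apart across the cell boundary instead of 0.8 apart inside the cell.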
- class tensorial.gcnn.EdgewiseDecoding(attrs, in_field='attributes', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Decode the direct sum of irreps stored in the in_field and store each tensor as a node value with key coming from the attrs.
- Parameters:
- in_field: str = 'attributes'#
- name: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- class tensorial.gcnn.EdgewiseEmbedding(attrs, out_field='attributes', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
- Parameters:
- name: str | None = None#
- out_field: str = 'attributes'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- tensorial.gcnn.EdgewiseEncoding#
alias of EdgewiseEmbedding
- class tensorial.gcnn.EdgewiseLinear(irreps_out, irreps_in=None, field='features', out_field='features', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Edgewise linear operation
- Parameters:
irreps_out (str | Irreps)
irreps_in (Irreps | None)
field (str)
out_field (str | None)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- field: str = 'features'#
- irreps_in: Irreps | None = None#
- irreps_out: str | Irreps#
- name: str | None = None#
- out_field: str | None = 'features'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- class tensorial.gcnn.Grad(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
The Grad class computes gradients of graph-based functions with respect to specified graph attributes. It enables automatic differentiation of operations defined on graph structures, such as computing how changes in node positions affect edge lengths or other graph properties. The class supports both scalar and vector-valued gradients and integrates with JAX for efficient computation.
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])
out_field (Union[str, tuple[str, ...], list[Union[str, tuple[str, ...]]]])
sign (float)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- func: Callable[[GraphsTuple], GraphsTuple]#
- name: str | None = None#
- of: str | tuple[str, ...]#
- out_field: str | tuple[str, ...] | list[str | tuple[str, ...]] = 'auto'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- sign: float = 1.0#
- wrt: str | tuple[str, ...] | list[str | tuple[str, ...]]#
- class tensorial.gcnn.GraphMetric(state=None)[source]#
Bases: Metric
- Parameters:
state (Optional[Metric[TypeVar(OutT)]])
- create(predictions, targets=None)[source]#
Create a new metric instance from data.
- Parameters:
predictions (GraphsTuple)
targets (GraphsTuple | None)
- Return type:
- property is_empty: bool#
- mask_key: ClassVar[str | tuple[str, ...] | None] = 'auto'#
- merge(other)[source]#
Merge the metric with data from another metric instance of the same type.
- Parameters:
other (GraphMetric)
- Return type:
- property metric: Metric[OutT] | None#
- normalise_by: ClassVar[str | tuple[str, ...] | None] = None#
- parent: ClassVar[Metric]#
- pred_key: ClassVar[str | tuple[str, ...]]#
- target_key: ClassVar[str | tuple[str, ...] | None] = None#
- class tensorial.gcnn.IndexedLinear(irreps_out, num_types, index_field, field, out_field=None, name=None, parent=<flax.linen.module._Sentinel object>)[source]#
Bases: Module
Applies an indexed linear transformation to a field in a GraphsTuple.
This module performs a linear transformation on a per-element basis, where each element is routed through a specific linear layer determined by an associated index array. A separate set of learnable weights is maintained for each index value.
- Variables:
irreps_out (str | e3j.Irreps) – The output irreducible representations of the linear transformation.
num_types (int) – Number of distinct index values, corresponding to the number of weight sets.
index_field (str) – Dot-separated path to the index array within the GraphsTuple.
field (str) – Dot-separated path to the input features within the GraphsTuple.
out_field (Optional[str]) – Dot-separated path where output features should be written. If None, overwrites field.
name (str) – Optional name for the internal Linear module.
- Parameters:
graph (jraph.GraphsTuple) – A graph with fields specified by index_field and field.
- Returns:
A new graph with updated features at out_field, where each input vector has been transformed by a linear layer corresponding to its associated index.
- Return type:
jraph.GraphsTuple
- Raises:
KeyError – If the specified field or index_field does not exist in the graph.
ValueError – If the index values exceed the range [0, num_types - 1].
Example
If graph.nodes contains input features and graph.nodes["type"] contains integer indices in [0, num_types), the module applies a learned linear map per type:
IndexedLinear("64x0e", num_types=5, index_field="nodes.type", field="nodes.feat")
Each node’s "feat" will be transformed by a different Linear layer according to its "type".
- Parameters:
irreps_out (str | Irreps)
num_types (int)
index_field (str)
field (str)
out_field (str | None)
name (Optional[str])
parent (Union[Module, Scope, _Sentinel, None])
- field: str#
- index_field: str#
- irreps_out: str | Irreps#
- name: str | None = None#
- num_types: int#
- out_field: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
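The per-index routing can be sketched with dense matrices in NumPy: one weight matrix per type, selected by each element's index. This is a simplification under stated assumptions. The real module works with irreps and learned parameters, while indexed_linear and the random weights here are hypothetical.

```python
import numpy as np


def indexed_linear(features, indices, weights):
    """Apply weights[indices[n]] to features[n] for every element n."""
    num_types = weights.shape[0]
    if indices.min() < 0 or indices.max() >= num_types:
        raise ValueError("index values must lie in [0, num_types - 1]")
    # Gather one (in_dim, out_dim) matrix per element, then contract
    return np.einsum("ni,nio->no", features, weights[indices])


rng = np.random.default_rng(0)
weights = rng.normal(size=(2, 3, 4))  # (num_types, in_dim, out_dim)
features = rng.normal(size=(5, 3))  # five nodes with three features each
types = np.array([0, 1, 1, 0, 1])

out = indexed_linear(features, types, weights)
```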
- class tensorial.gcnn.IndexedRescale(num_types, index_field, field, out_field=None, shifts=None, scales=None, rescale_init=<function variance_scaling.<locals>.init>, shift_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Applies a per-type affine transformation (scale and shift) to a specified field in a graph.
Each input is scaled and shifted based on an associated index (e.g. atomic or node type). The transformation is of the form: output = input * scale + shift, where both scale and shift are either learnable parameters or provided constants, indexed by the value in index_field.
This is typically used to normalize or denormalize features like node energies, depending on the type of node or atom.
- Variables:
num_types (int) – Number of unique types (i.e. distinct values in index_field). Determines the number of learnable scale and shift parameters.
index_field (str) – Path (e.g. “nodes.type”) to the array of indices used to select the scale and shift for each input.
field (str) – Path to the input field to be rescaled.
out_field (Optional[str]) – Path to the output field. If None, the result is written to field.
shifts (Optional[ArrayLike]) – Optional constant shift values of shape (num_types,). If None, the shifts are learned parameters initialized with shift_init.
scales (Optional[ArrayLike]) – Optional constant scale values of shape (num_types, 1). If None, the scales are learned parameters initialized with rescale_init.
rescale_init (Initializer) – Initializer for learnable scale parameters.
shift_init (Initializer) – Initializer for learnable shift parameters.
- Returns:
A new graph with the specified field transformed and stored at out_field.
- Return type:
jraph.GraphsTuple
- Raises:
ValueError – If the number of types does not match the shape of provided scales or shifts.
Notes
Supports e3nn_jax.IrrepsArray input and preserves irreps metadata.
Uses jax.vmap internally for efficiency across nodes.
- Parameters:
num_types (int)
index_field (str)
field (str)
out_field (str | None)
shifts (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, None])
scales (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray, None])
rescale_init (Union[Initializer, Callable[..., Any]])
shift_init (Union[Initializer, Callable[..., Any]])
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- field: str#
- index_field: str#
- name: str | None = None#
- num_types: int#
- out_field: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- rescale_init(shape: Sequence[int | Any], dtype: Any | None = None, out_sharding: NamedSharding | PartitionSpec | None = None) Array#
- Parameters:
key (Array)
shape (Sequence[Union[int, Any]])
dtype (Any | None)
out_sharding (NamedSharding | PartitionSpec | None)
- Return type:
Array
- scales: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- shift_init(shape, dtype=None, out_sharding=None)#
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- Parameters:
key (Array)
shape (Sequence[Union[int, Any]])
dtype (Any | None)
out_sharding (NamedSharding | PartitionSpec | None)
- Return type:
Array
- shifts: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None#
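The transformation output = input * scale + shift, indexed by type, can be sketched in a few lines of NumPy. The function name and the constant values are hypothetical, and both parameter arrays are taken as one-dimensional here for simplicity, unlike the (num_types, 1) scale shape described above.

```python
import numpy as np


def indexed_rescale(values, indices, scales, shifts):
    """Scale and shift each value using the parameters of its type."""
    return values * scales[indices] + shifts[indices]


energies = np.array([1.0, 2.0, 3.0])  # one value per node
types = np.array([0, 1, 0])  # type index of each node
scales = np.array([2.0, 10.0])  # one scale per type
shifts = np.array([0.5, -1.0])  # one shift per type

rescaled = indexed_rescale(energies, types, scales, shifts)
```

The first and third nodes share type 0, so both use scale 2.0 and shift 0.5, while the second node uses the type 1 parameters.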
- class tensorial.gcnn.Jacfwd(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
of (Union[str, tuple[str, ...]])
wrt (str | Sequence[Union[str, tuple[str, ...]]])
out_field (str | Sequence[str])
sign (float)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- func: Callable[[GraphsTuple], GraphsTuple]#
- name: str | None = None#
- of: str | tuple[str, ...]#
- out_field: str | Sequence[str] = 'auto'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- sign: float = 1.0#
- wrt: str | Sequence[str | tuple[str, ...]]#
- class tensorial.gcnn.Jacobian(func, of, wrt, out_field='auto', sign=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
of (Union[str, tuple[str, ...]])
wrt (str | Sequence[Union[str, tuple[str, ...]]])
out_field (str | Sequence[str])
sign (float)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- func: Callable[[GraphsTuple], GraphsTuple]#
- name: str | None = None#
- of: str | tuple[str, ...]#
- out_field: str | Sequence[str] = 'auto'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- sign: float = 1.0#
- wrt: str | Sequence[str | tuple[str, ...]]#
- class tensorial.gcnn.Loss(loss_fn, targets, predictions=None, *, reduction='mean', label=None, mask_field=None)[source]#
Bases: GraphLoss
Simple loss function that passes values from the graph to a function taking numerical values such as optax losses
Initializes the loss function wrapper with specified parameters.
This constructor sets up a loss function wrapper that can be used to compute loss values based on target and prediction fields. It supports various loss functions and provides options for reduction and masking.
- Parameters:
loss_fn (str | Callable[[Array, Array], Array]) – The loss function to use, either as a string identifier or a callable implementing the PureLossFn protocol.
targets (str) – The path to the target field in the data structure.
predictions (str | None) – The path to the prediction field in the data structure. If None, the target field path will be used.
reduction (Literal['sum', 'mean']) – The reduction method to apply to the computed losses. Either “sum” or “mean”.
label (str) – The label to use for the loss function. If None, the prediction field path will be used as the label.
mask_field (str | None) – The path to the mask field in the data structure. If None, no masking will be applied.
- Raises:
ValueError – If the reduction method is not “sum” or “mean”.
TypeError – If the loss function is not a valid string or callable.
- class tensorial.gcnn.NequipLayer(irreps_out, invariant_layers=1, invariant_neurons=8, radial_num_layers=1, radial_num_neurons=8, radial_activation='swish', avg_num_neighbours=1.0, activations=FrozenDict({ e: 'silu', o: 'tanh', }), skip_connection=True, num_species=1, interaction_block=None, resnet=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
NequIP convolution layer.
Implementation based on: mir-group/nequip
- Parameters:
irreps_out (Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, Union[None, Irrep, MulIrrep, str, Irreps, Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]]])
invariant_layers (int)
invariant_neurons (int)
radial_num_layers (int)
radial_num_neurons (int)
radial_activation (str | Callable[[Array], Array])
avg_num_neighbours (float | dict[int, float])
activations (str | Mapping[str, str | Callable[[Array], Array]])
skip_connection (bool)
num_species (int)
interaction_block (Callable)
resnet (bool)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- activations: str | Mapping[str, str | Callable[[Array], Array]] = FrozenDict({ e: 'silu', o: 'tanh', })#
- avg_num_neighbours: float | dict[int, float] = 1.0#
- interaction_block: Callable = None#
- invariant_layers: int = 1#
- invariant_neurons: int = 8#
- irreps_out: None | Irrep | MulIrrep | str | Irreps | Sequence[str | Irrep | MulIrrep | tuple[int, None | Irrep | MulIrrep | str | Irreps | Sequence[str | Irrep | MulIrrep | tuple[int, IntoIrreps]]]]#
- name: str | None = None#
- node_features_field = 'features'#
- num_species: int = 1#
- parent: Module | Scope | _Sentinel | None = None#
- radial_activation: str | Callable[[Array], Array] = 'swish'#
- radial_num_layers: int = 1#
- radial_num_neurons: int = 8#
- resnet: bool = False#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
- Immediately when invoking apply(), init() or init_and_output().
- Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...     def setup(self):
...         submodule = nn.Conv(...)
...         # Accessing `submodule` attributes does not yet work here.
...         # The following line invokes `self.__setattr__`, which gives
...         # `submodule` the name "conv1".
...         self.conv1 = submodule
...         # Accessing `submodule` attributes or methods is now safe and
...         # either causes setup() to be called once.
- Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- skip_connection: bool = True#
- class tensorial.gcnn.NodewiseDecoding(attrs, in_field='attributes', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Decode the direct sum of irreps stored in the in_field and store each tensor as a node value with key coming from the attrs.
- Parameters:
- in_field: str = 'attributes'#
- name: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- class tensorial.gcnn.NodewiseEmbedding(attrs, out_field='attributes', node_shape_from='positions', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Take the attributes in the nodes dictionary given by attrs, embed them, and store the results as a direct sum of irreps in the out_field.
- Parameters:
- name: str | None = None#
- node_shape_from: str | None = 'positions'#
- out_field: str = 'attributes'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- tensorial.gcnn.NodewiseEncoding#
alias of
NodewiseEmbedding
- class tensorial.gcnn.NodewiseLinear(irreps_out, irreps_in=None, field='features', out_field='features', num_types=None, types_field=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Nodewise linear operation
- Parameters:
irreps_out (str | Irreps)
irreps_in (Irreps | None)
field (str)
out_field (str | None)
num_types (int | None)
types_field (str | None)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- field: str = 'features'#
- irreps_in: Irreps | None = None#
- irreps_out: str | Irreps#
- name: str | None = None#
- num_types: int | None = None#
- out_field: str | None = 'features'#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
- types_field: str | None = None#
- class tensorial.gcnn.NodewiseReduce(field, out_field=None, reduce='sum', average_num_atoms=None, as_array=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Applies a reduction operation over node features and stores the result in the graph globals.
This module reduces a specified field in the graph’s node features across all nodes (within each graph if batched) using a specified reduction operation (sum, mean, or normalized_sum). The result is written to the globals field of the GraphsTuple.
- Variables:
field (str) – Path to the node field to reduce, e.g. “energy” or “features.energy”.
out_field (Optional[str]) – Path to the output global field. If None, defaults to “<reduce>_<field>” under globals.
reduce (str) – Reduction operation to apply. Must be one of “sum”, “mean”, or “normalized_sum”.
average_num_atoms (float) – Required if reduce is “normalized_sum”. Used to scale the result by average_num_atoms ** -0.5.
- Raises:
ValueError – If reduce is not one of the allowed options.
ValueError – If reduce == “normalized_sum” but average_num_atoms is not provided.
- Returns:
A new GraphsTuple with the reduced value written to the specified globals field.
- Parameters:
field (str)
out_field (str | None)
reduce (str)
average_num_atoms (float)
as_array (bool)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- as_array: bool = False#
- average_num_atoms: float = None#
- field: str#
- name: str | None = None#
- out_field: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- reduce: str = 'sum'#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
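The three reduction modes described for NodewiseReduce can be sketched with a segment sum over a batch of graphs. This is a minimal illustration in plain JAX, assuming per-node scalar energies and an n_node array giving the number of nodes in each graph; the variable names and the value of average_num_atoms are made up for the example, and the module's actual implementation may differ.

```python
import jax
import jax.numpy as jnp

# Two graphs batched together: the first has 2 nodes, the second has 3.
n_node = jnp.array([2, 3])
node_energy = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
num_graphs = n_node.shape[0]

# Map every node to the index of the graph it belongs to: [0, 0, 1, 1, 1].
graph_index = jnp.repeat(
    jnp.arange(num_graphs), n_node, total_repeat_length=node_energy.shape[0]
)

# "sum": add up the node values within each graph.
sum_energy = jax.ops.segment_sum(node_energy, graph_index, num_segments=num_graphs)

# "mean": divide each per-graph sum by that graph's node count.
mean_energy = sum_energy / n_node

# "normalized_sum": scale the per-graph sum by average_num_atoms ** -0.5.
average_num_atoms = 4.0  # illustrative value only
normalized_sum_energy = sum_energy * average_num_atoms ** -0.5
```

Each result has one entry per graph, which is the shape expected of a value stored in the globals.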
- class tensorial.gcnn.RadialBasisEdgeEmbedding(field='edge_lengths', out_field='radial_embeddings', num_basis=8, r_max=4.0, envelope=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
- Parameters:
field (str)
out_field (str)
num_basis (int)
r_max (float)
envelope (bool | Callable)
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- envelope: bool | Callable = False#
- field: str = 'edge_lengths'#
- name: str | None = None#
- num_basis: int = 8#
- out_field: str = 'radial_embeddings'#
- parent: Module | Scope | _Sentinel | None = None#
- r_max: float = 4.0#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
- tensorial.gcnn.RadialBasisEdgeEncoding#
alias of
RadialBasisEdgeEmbedding
- class tensorial.gcnn.Rescale(shift_fields=(), scale_fields=(), shift=0.0, scale=1.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases: Module
Applies constant rescaling and/or shifting to fields in a jraph.GraphsTuple.
This module modifies specified fields in the graph — which may be located in the nodes, edges, or globals — by multiplying them with a scalar factor (scale) and/or adding a constant offset (shift). This is useful for normalizing or denormalizing values, or applying consistent physical unit conversions.
Both scale_fields and shift_fields may be either a single string (e.g. “nodes.energy”) or a sequence of path strings. Missing fields are ignored silently.
Example usage:#
>>> Rescale(shift_fields='nodes.energy', shift=12.5)  # shifts the energy stored in each node by 12.5
>>> Rescale(scale_fields=['globals.volume'], scale=1e-3)  # rescales the global volume by 1e-3
Attributes:#
- shift_fields: str | Sequence[Hashable]
Path(s) to the fields to which a constant shift should be applied.
- scale_fields: str | Sequence[Hashable]
Path(s) to the fields to which a constant scale should be applied.
- shift: jax.Array
Scalar constant to be added to all values in shift_fields. Defaults to 0.0.
- scale: jax.Array
Scalar constant to multiply all values in scale_fields. Defaults to 1.0.
Notes:#
Fields that are not found in the graph are skipped silently.
If a global field is shifted, a warning is logged that the field will no longer be size extensive with respect to the number of nodes or edges.
- Parameters:
shift_fields (str | Sequence[Hashable])
scale_fields (str | Sequence[Hashable])
shift (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])
scale (Union[Array, ndarray, bool, number, bool, int, float, complex, TypedNdArray])
parent (Union[Module, Scope, _Sentinel, None])
name (Optional[str])
- name: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0#
- scale_fields: str | Sequence[Hashable] = ()#
- scope: Scope | None = None#
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__).
- shift: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0#
- shift_fields: str | Sequence[Hashable] = ()#
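The path-based behaviour described for Rescale can be sketched on a plain nested dictionary. The function rescale_sketch and its helpers below are hypothetical stand-ins that operate on dicts rather than a jraph.GraphsTuple, and applying the scale before the shift is an assumption; the point is only to show dotted paths being resolved, missing fields being skipped, and the input being left unmodified.

```python
import copy


def _as_paths(fields):
    # Accept either a single path string or a sequence of path strings.
    return (fields,) if isinstance(fields, str) else tuple(fields)


def _update(tree, path, fn):
    # Walk a dotted path such as "nodes.energy" and update the leaf in place.
    *parents, leaf = path.split(".")
    node = tree
    for key in parents:
        if not isinstance(node, dict) or key not in node:
            return  # missing fields are skipped silently
        node = node[key]
    if isinstance(node, dict) and leaf in node:
        node[leaf] = fn(node[leaf])


def rescale_sketch(graph, shift_fields=(), scale_fields=(), shift=0.0, scale=1.0):
    """Hypothetical dict-based stand-in for the behaviour described above."""
    out = copy.deepcopy(graph)  # leave the input untouched
    for path in _as_paths(scale_fields):
        _update(out, path, lambda value: value * scale)
    for path in _as_paths(shift_fields):
        _update(out, path, lambda value: value + shift)
    return out


graph = {"nodes": {"energy": 1.5}, "globals": {"volume": 2000.0}}
shifted = rescale_sketch(graph, shift_fields="nodes.energy", shift=12.5)
scaled = rescale_sketch(graph, scale_fields=["globals.volume", "globals.missing"], scale=1e-3)
```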
- class tensorial.gcnn.WeightedLoss(loss_fns, weights=None)[source]#
Bases: GraphLoss
A weighted combination of multiple graph loss functions.
This class combines multiple graph loss functions with specified weights to create a composite loss function. The weights determine the contribution of each individual loss function to the final combined loss.
- Parameters:
loss_fns (Sequence[GraphLoss])
weights (Sequence[float] | None)
- Raises:
ValueError – If any element in loss_fns is not a subclass of GraphLoss, or if the number of weights does not match the number of loss functions.
- loss_with_contributions(predictions, target)[source]#
- Parameters:
predictions (GraphsTuple)
target (GraphsTuple)
- Return type:
tuple[Array,dict[str,float]]
- property weights#
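The arithmetic behind a weighted combination with per-term contributions can be sketched as follows. The helper weighted_sum is hypothetical and works on already-computed scalar losses keyed by label; defaulting to equal weights and reporting weighted (rather than raw) contributions are assumptions, not confirmed behaviour of WeightedLoss.

```python
def weighted_sum(losses, weights=None):
    """Hypothetical helper: combine labelled scalar losses with weights.

    Returns the total and a dict with each label's weighted contribution.
    """
    labels = list(losses)
    if weights is None:
        weights = [1.0] * len(labels)  # assumption: equal weights by default
    if len(weights) != len(labels):
        raise ValueError("the number of weights must match the number of losses")

    contributions = {
        label: weight * losses[label] for label, weight in zip(labels, weights)
    }
    return sum(contributions.values()), contributions


total, contributions = weighted_sum({"energy": 2.0, "forces": 0.5}, weights=[1.0, 10.0])
```

Returning the contributions alongside the total mirrors the tuple[Array, dict[str, float]] return type of loss_with_contributions and makes it easy to log each term separately.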
- tensorial.gcnn.adapt(fn, *args, outs=(), return_graphs=False, **keywords)[source]#
Given a graph function, this will return a function that takes a graph as the first argument followed by positional arguments that will be mapped to the fields given by the input paths (*args). Output paths can optionally be specified with outs which, if supplied, will make the function return one or more values from the graph as returned by fn.
- Parameters:
fn (ExGraphFunction) – the graph function
*args (Union[str, tuple[str, ...]]) – the input paths
outs (Sequence[Union[str, tuple[str, ...]]]) – the output paths
return_graphs (bool) – if True and outs is specified, this will return a tuple containing the output graph followed by the values at outs
keywords (Union[str, tuple[str, ...]])
- Return type:
TransformedGraphFunction
- Returns:
a function that wraps fn with the above properties
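The wrapping pattern described for adapt can be sketched with flat dictionaries standing in for graphs. The name adapt_sketch is hypothetical, keyword inputs are omitted, and returning a bare value when only one output path is given is an assumption; the real function operates on graph objects and nested paths.

```python
def adapt_sketch(fn, *ins, outs=(), return_graphs=False):
    """Hypothetical, simplified stand-in that uses flat dicts as graphs."""

    def wrapped(graph, *values):
        # Write each positional value to its input path before calling fn.
        updated = {**graph, **dict(zip(ins, values))}
        result = fn(updated)
        if not outs:
            return result
        picked = tuple(result[path] for path in outs)
        if return_graphs:
            # The output graph first, followed by the requested values.
            return (result, *picked)
        # Assumption: a single output path yields a bare value.
        return picked[0] if len(picked) == 1 else picked

    return wrapped


def total_energy(graph):
    # Toy graph function: reads two fields and writes a new one.
    return {**graph, "energy": sum(graph["charges"]) * graph["scale"]}


energy_of = adapt_sketch(total_energy, "charges", outs=("energy",))
energy = energy_of({"scale": 2.0}, [1.0, 2.0, 3.0])
```

Wrapping a graph function this way gives it an ordinary array-in, array-out signature, which is the form that transformations such as jax.grad expect.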
- tensorial.gcnn.diff(*func_of, wrt, out=None, scale=1.0, at=None, return_graph=False)[source]#
Constructs a JAX-compatible evaluator for computing single or multiple derivatives of a scalar function (e.g., energy) defined over a Graph.
This function acts as a factory, routing the request to create either a SingleDerivative or MultiDerivative object based on the wrt argument.
The key feature of this function is the use of string-based tensor index notation (e.g., ‘nodes.positions:Iγ’) for specifying differentiation targets and output shape.
- Parameters:
*func_of – Either (func), where ‘func’ is the energy function, or (func, of), where ‘of’ is the scalar entry (e.g., ‘globals.energy’) to differentiate.
  - func (Callable): The function (Graph -> Graph) whose output is differentiated.
  - of (GraphEntrySpecLike, optional): Specifies the scalar entry within the output graph of ‘func’ to differentiate. Defaults to the sole scalar output if omitted.
wrt (GraphEntrySpecLike | Sequence[GraphEntrySpecLike]) – The input entries of the graph with respect to which the derivative is taken. This must be a string or a sequence of strings, specifying the index notation for the input:
  - Example: ‘nodes.positions:Iγ’ means differentiate w.r.t. the $\gamma$ component of the position of node $I$.
out (GraphEntrySpecLike, optional) – The index notation for the desired output tensor shape. This string defines the contraction of the indices from ‘wrt’. If omitted, the indices are concatenated in the order they appear in ‘wrt’.
  - Example: If wrt=[‘field:α’, ‘field:β’, ‘positions:Iγ’], out=‘:Iγαβ’ specifies a rank-4 tensor output with indices $I, \gamma, \alpha, \beta$.
scale (float) – A scalar factor to multiply the final derivative result by. Defaults to 1.0.
at (dict | None) – A dictionary mapping GraphEntrySpecLike strings (without indices) to jax.numpy arrays, specifying the value at which to evaluate the derivative for those entries. These entries will be held constant.
  - Example: {‘globals.electric_field’: jnp.zeros(3)}
return_graph (bool) – If True, the derivative tensor is packaged into a new Graph object under the name specified by the ‘out’ argument. If False, the function returns the raw derivative tensor. Defaults to False.
- Returns:
A callable object that takes a Graph and returns the computed derivative tensor (or a Graph containing it).
- Return type:
Evaluator
- Raises:
TypeError – If the arguments do not conform to the expected types.
Note
The index notation used in ‘wrt’ and ‘out’ must adhere to the library’s conventions for Graph entry keys and indices. For multi-derivatives, the number of unique indices in ‘out’ must match the number of indices in ‘wrt’.
- Parameters:
wrt (str | GraphEntrySpec | Sequence[str | GraphEntrySpec])
out (str | GraphEntrySpec)
scale (float)
at (dict | None)
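To make the index notation more concrete, the sketch below computes the same kind of quantities with plain JAX transformations rather than with diff itself. The toy energy function and all variable names are assumptions for illustration: the first derivative carries the indices Iγ of ‘nodes.positions:Iγ’, the mixed derivative gains a field index α, and choosing a different index order in ‘out’ corresponds to a transpose.

```python
import jax
import jax.numpy as jnp


def energy(positions, field):
    # Toy scalar energy coupling a field (index α) to node positions (indices I, γ).
    dipole = positions.sum(axis=0)  # shape (3,)
    return 0.5 * jnp.dot(field, dipole) ** 2


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.5, 0.0]])  # shape (N, 3)
field = jnp.array([0.1, 0.0, 0.2])                          # shape (3,)

# dE/d positions: one entry per node I and Cartesian component γ.
d_pos = jax.grad(energy, argnums=0)(positions, field)

# d²E / (d positions d field): output indices in the order I, γ, α.
d_pos_field = jax.jacfwd(jax.grad(energy, argnums=0), argnums=1)(positions, field)

# Requesting the order α, I, γ instead amounts to permuting the axes.
d_field_first = jnp.transpose(d_pos_field, (2, 0, 1))
```

Composing a forward-mode Jacobian over a reverse-mode gradient is one common way to build mixed second derivatives of a scalar; whether diff uses this particular composition internally is not stated here.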
- tensorial.gcnn.grad(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialized Grad function
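A typical use of the sign argument is obtaining forces as the negative gradient of an energy with respect to positions. The sketch below shows that idea with plain jax.grad on arrays; signed_grad and pair_energy are hypothetical names, and the real grad works on graph functions and field paths rather than bare arrays.

```python
import jax
import jax.numpy as jnp


def pair_energy(positions):
    # Harmonic bond between two atoms with a rest length of 1.
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * (r - 1.0) ** 2


def signed_grad(fn, sign=1.0):
    """Hypothetical helper: gradient of fn multiplied by a constant sign."""
    grad_fn = jax.grad(fn)
    return lambda x: sign * grad_fn(x)


# Forces are the negative gradient of the energy, hence sign=-1.0.
forces_fn = signed_grad(pair_energy, sign=-1.0)

positions = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
forces = forces_fn(positions)
```

With the bond stretched to twice its rest length, the resulting forces pull the two atoms towards each other along x.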
- tensorial.gcnn.graph_from_points(pos, r_max, *, fractional_positions=False, self_interaction=True, strict_self_interaction=False, cell=None, pbc=None, nodes=None, edges=None, graph_globals=None, np_=<module 'numpy'>)[source]#
Create a jraph Graph from a set of atomic positions and other related data.
- Parameters:
pos (Union[Float[Array, 'n_nodes 3'], Float[ndarray, 'n_nodes 3']]) – a [N, 3] array of atomic positions
r_max (Number) – the cutoff radius to use for identifying neighbours
fractional_positions (bool) – if True, pos are interpreted as fractional positions
self_interaction (bool) – if True, edges are created between an atom and itself in other unit cells
strict_self_interaction (bool) – if True, edges are created between an atom and itself within the central unit cell
cell (Optional[Union[Float[Array, '3 3'], Float[ndarray, '3 3']]]) – a [3, 3] array of unit cell vectors (in row-major format)
pbc (Union[bool, Union[tuple[bool, bool, bool], Bool[Array, '3'], Bool[ndarray, '3'], Bool[TypedNdArray, '3']], None]) – a bool or a sequence of three bools indicating whether the space is periodic in the x, y and z directions
nodes (dict[str, Union[Num[Array, 'n_nodes *'], Num[ndarray, 'n_nodes *']]] | None) – a dictionary containing additional data relating to each node; it should contain arrays of shape [N, …]
edges (dict | None)
graph_globals (dict[str, Union[Array, ndarray]] | None) – a dictionary containing additional global data
- Return type:
GraphsTuple
- Returns:
the corresponding jraph Graph
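The core of building a graph from points is a cutoff-based neighbour search. The sketch below shows the simplest, non-periodic case with NumPy: every ordered pair of distinct atoms closer than r_max becomes an edge. The helper neighbour_edges is hypothetical, ignores cell and pbc entirely, and treating a distance exactly equal to r_max as inside the cutoff is an assumption.

```python
import numpy as np


def neighbour_edges(pos, r_max):
    """Hypothetical helper: directed edges between distinct points within r_max.

    No periodic images are considered, so self-interaction edges never appear.
    """
    pos = np.asarray(pos, dtype=float)
    # deltas[i, j] is the vector from point i to point j.
    deltas = pos[None, :, :] - pos[:, None, :]
    dists = np.linalg.norm(deltas, axis=-1)

    within = dists <= r_max
    np.fill_diagonal(within, False)  # drop the i == j pairs

    senders, receivers = np.nonzero(within)
    return senders, receivers


pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
senders, receivers = neighbour_edges(pos, r_max=1.5)
```

The dense pairwise distance matrix costs O(N²) memory, which is fine for a sketch but is usually replaced by a cell-list or similar search for large systems.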
- tensorial.gcnn.graph_metric(metric, predictions, targets=None, mask='auto', normalise_by=None)[source]#
- Parameters:
metric (str | Metric | type[Metric])
predictions (Union[str, tuple[str, ...]])
targets (Union[str, tuple[str, ...], None])
mask (Union[str, tuple[str, ...], Literal['auto'], None])
normalise_by (Union[str, tuple[str, ...], None])
- Return type:
- tensorial.gcnn.hessian(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialized Grad function
- tensorial.gcnn.jacfwd(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialized Grad function
- tensorial.gcnn.jacobian(of, wrt, sign=1.0, has_aux=False)#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialized Grad function
- tensorial.gcnn.jacrev(of, wrt, sign=1.0, has_aux=False)[source]#
Build a partially initialised Grad function whose only
- Parameters:
kwargs – accepts any arguments that Grad does
of (Union[str, tuple[str, ...]])
wrt (Union[str, tuple[str, ...], Sequence[Union[str, tuple[str, ...]]]])
sign (float)
has_aux (bool)
- Return type:
Callable[[Callable[[GraphsTuple], GraphsTuple]], Callable[[GraphsTuple, ...], GraphsTuple | PyTree | tuple[PyTree]]]
- Returns:
the partially initialized Grad function
- tensorial.gcnn.shape_check(func)[source]#
Decorator that logs any differences in the keys present in the graph before and after the call, or any differences in their shapes.
This is particularly useful for diagnosing JAX re-compilation issues.
- Parameters:
func (Callable[[GraphsTuple], GraphsTuple])
- Return type:
Callable[[GraphsTuple], GraphsTuple]
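The idea behind such a decorator can be sketched on flat dictionaries of arrays. The names shape_differences and shape_check_sketch are hypothetical, and the log format is invented; the sketch only shows comparing keys and shapes before and after a call and reporting what changed.

```python
import functools
import logging

import numpy as np

_LOGGER = logging.getLogger(__name__)


def shape_differences(before, after):
    """Return {key: (shape_before, shape_after)} for every key that changed.

    A shape of None means the key was absent on that side.
    """
    shapes_before = {key: np.shape(value) for key, value in before.items()}
    shapes_after = {key: np.shape(value) for key, value in after.items()}
    return {
        key: (shapes_before.get(key), shapes_after.get(key))
        for key in sorted(shapes_before.keys() | shapes_after.keys())
        if shapes_before.get(key) != shapes_after.get(key)
    }


def shape_check_sketch(func):
    """Hypothetical decorator: log key and shape changes made by func."""

    @functools.wraps(func)
    def wrapper(graph):
        out = func(graph)
        for key, (old, new) in shape_differences(graph, out).items():
            _LOGGER.warning("%s: %s changed from %s to %s", func.__name__, key, old, new)
        return out

    return wrapper


@shape_check_sketch
def add_lengths(graph):
    # Adds a new field, which the decorator reports as a change.
    return {**graph, "edge_lengths": np.linalg.norm(graph["edge_vectors"], axis=-1)}


inputs = {"edge_vectors": np.zeros((4, 3))}
result = add_lengths(inputs)
```

Because jit-compiled JAX functions are retraced whenever input shapes or pytree structure change, a report like this can point to the call that introduced the change.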
- tensorial.gcnn.transform_fn(fn, *args, outs=(), return_graphs=False, **keywords)#
Given a graph function, this will return a function that takes a graph as the first argument followed by positional arguments that will be mapped to the fields given by the input paths (*args). Output paths can optionally be specified with outs which, if supplied, will make the function return one or more values from the graph as returned by fn.
- Parameters:
fn (ExGraphFunction) – the graph function
*args (Union[str, tuple[str, ...]]) – the input paths
outs (Sequence[Union[str, tuple[str, ...]]]) – the output paths
return_graphs (bool) – if True and outs is specified, this will return a tuple containing the output graph followed by the values at outs
keywords (Union[str, tuple[str, ...]])
- Return type:
TransformedGraphFunction
- Returns:
a function that wraps fn with the above properties
- tensorial.gcnn.with_edge_vectors(graph, with_lengths=True, as_irreps_array=True)[source]#
Compute edge displacements for edge vectors in a graph.
This adds edge attributes that cache the vectors and displacements, so they are not recalculated if they are already present.
- Parameters:
graph (GraphsTuple)
with_lengths (bool)
as_irreps_array (bool | None)
- Return type:
GraphsTuple
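The compute-once behaviour can be sketched with NumPy on plain arrays. The function with_edge_vectors_sketch and the key names edge_vectors and edge_lengths are assumptions for illustration, as is the sender-to-receiver direction of the vectors; periodic cell shifts and irreps arrays are left out entirely.

```python
import numpy as np


def with_edge_vectors_sketch(positions, senders, receivers, edges=None, with_lengths=True):
    """Hypothetical helper: add cached edge vectors (and lengths) to an edge dict."""
    edges = dict(edges or {})
    if "edge_vectors" not in edges:
        # Assumption: vectors point from the sender to the receiver.
        edges["edge_vectors"] = positions[receivers] - positions[senders]
    if with_lengths and "edge_lengths" not in edges:
        edges["edge_lengths"] = np.linalg.norm(edges["edge_vectors"], axis=-1)
    return edges


positions = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
senders = np.array([0, 1])
receivers = np.array([1, 0])

edges = with_edge_vectors_sketch(positions, senders, receivers)
# A second call finds the cached entries and reuses them.
edges_again = with_edge_vectors_sketch(positions, senders, receivers, edges=edges)
```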