tensorial.gcnn.experimental package

Submodules

tensorial.gcnn.experimental.derivatives module

class tensorial.gcnn.experimental.derivatives.GraphEntrySpec(key_path, indices)

Bases: object

A specification that identifies a particular entry in a hierarchical data structure (e.g., a PyTree) along with optional index labels used for differentiable computations.

Variables:
  • key_path (Optional[gcnn.typing.TreePath]) – A path to the target node in the PyTree, typically represented as a tuple of keys (e.g., strings or integers).

  • indices (Optional[str]) – A string representing symbolic indices, often used to annotate tensor dimensions for operations like differentiation.

Parameters:
  • key_path (int | tuple[str, ...] | None)

  • indices (str | None)

classmethod create(spec)
Parameters:

spec (str | GraphEntrySpec)

Return type:

GraphEntrySpec
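The string form accepted by create is not spelled out here. Judging only from the examples in the diff documentation ('nodes.positions:Iγ', 'globals.energy', ':Iγαβ'), a spec appears to be a dot-separated key path optionally followed by a colon and an index string. The sketch below parses that assumed format into the two fields; parse_entry_spec and EntrySpecSketch are hypothetical stand-ins, not the library's implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EntrySpecSketch:
    """Hypothetical stand-in for GraphEntrySpec: a key path plus optional indices."""
    key_path: Optional[tuple[str, ...]]
    indices: Optional[str]


def parse_entry_spec(spec: str) -> EntrySpecSketch:
    """Parse an assumed 'key.path:indices' string.

    Assumptions (not confirmed by the library docs):
      * '.' separates the keys of the path,
      * an optional ':' introduces the index labels,
      * an empty path (as in ':Iγαβ') means "no key path".
    """
    path_part, sep, index_part = spec.partition(":")
    key_path = tuple(path_part.split(".")) if path_part else None
    indices = index_part if sep and index_part else None
    return EntrySpecSketch(key_path=key_path, indices=indices)


print(parse_entry_spec("nodes.positions:Iγ"))
```

Because str.partition splits at the first colon only, index strings cannot themselves contain ':' under this assumed format.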

index_union(other)
Parameters:

other (GraphEntrySpec)

Return type:

str | None

indices: str | None
indices_intersection(other)
Parameters:

other (GraphEntrySpec)

Return type:

str

key_path: int | tuple[str, ...] | None
property safe_indices: str

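index_union and indices_intersection are documented only by their types. One plausible reading, stated here purely as an assumption, is that they treat each index string as an ordered set of single-character labels: the union keeps the labels of the first spec and appends unseen labels from the second, and the intersection keeps the labels of the first spec that also occur in the second. The two functions below are hypothetical helpers that act on raw index strings.

```python
from typing import Optional


def index_union(a: Optional[str], b: Optional[str]) -> Optional[str]:
    """Hypothetical union of two index strings, preserving first-seen order.

    Returns None only when neither side carries indices, an assumption that
    mirrors the documented ``str | None`` return type.
    """
    if a is None and b is None:
        return None
    merged: list[str] = []
    for label in (a or "") + (b or ""):
        if label not in merged:
            merged.append(label)
    return "".join(merged)


def indices_intersection(a: Optional[str], b: Optional[str]) -> str:
    """Hypothetical intersection: labels of ``a`` (deduplicated) that also appear in ``b``."""
    other = set(b or "")
    return "".join(label for label in dict.fromkeys(a or "") if label in other)


print(index_union("Iγ", "γα"))           # Iγα
print(indices_intersection("Iγα", "αγ"))  # γα
```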
class tensorial.gcnn.experimental.derivatives.MultiDerivative(parts)

Bases: Derivative

Parameters:

parts (tuple[SingleDerivative, ...])

property argnum_paths: dict[int, int]
build_derivative_fn(func, return_graph, argnum)

Build the function that evaluates this derivative for func.

Parameters:
  • func (DerivableGraphFunction)

  • return_graph (bool)

  • argnum (int)

Return type:

DerivableGraphFunction

classmethod create(of, wrt, out=None)
Return type:

MultiDerivative

property graph_tuple_paths: dict[tuple[str, ...], int]
property of: GraphEntrySpec

The entry whose derivative is taken.

property out: GraphEntrySpec

The specification of the derivative output.

parts: tuple[SingleDerivative, ...]
property paths: dict[int | tuple[str, ...], int]

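A MultiDerivative is built from a tuple of SingleDerivative parts, and diff chooses between the two based on wrt. Assuming, as the multi-entry wrt example for diff suggests, that the parts are applied one after another to produce a higher-order mixed derivative, the idea can be shown without JAX. This sketch deliberately swaps automatic differentiation for central finite differences; central_partial and mixed are hypothetical helpers and say nothing about how the library computes derivatives.

```python
from typing import Callable, Sequence

Func = Callable[[Sequence[float]], float]


def central_partial(f: Func, i: int, h: float = 1e-3) -> Func:
    """Central-difference approximation of df/dx_i (a stand-in for autodiff)."""
    def df(x: Sequence[float]) -> float:
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        return (f(xp) - f(xm)) / (2 * h)
    return df


def mixed(f: Func, wrt: Sequence[int]) -> Func:
    """Chain single partial derivatives, one per entry of ``wrt``."""
    for i in wrt:
        f = central_partial(f, i)
    return f


def f(x: Sequence[float]) -> float:
    # f(x0, x1) = x0**2 * x1, so the mixed partial d2f/dx0 dx1 equals 2 * x0.
    return x[0] ** 2 * x[1]


d2f = mixed(f, wrt=[0, 1])
print(d2f([1.5, -0.7]))  # ≈ 3.0
```

Each part only needs to know which input it differentiates; the chaining in mixed is what turns a sequence of single derivatives into one higher-order derivative.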
class tensorial.gcnn.experimental.derivatives.SingleDerivative(_of, _wrt, _out)

Bases: Derivative

property argnum_paths: dict[int, int]
build_derivative_fn(func, return_graph, argnum)

Build the function that evaluates this derivative for func.

Parameters:
  • func (DerivableGraphFunction)

  • return_graph (bool)

  • argnum (int)

Return type:

DerivableGraphFunction

classmethod create(of, wrt, out=None)
Return type:

SingleDerivative

property graph_tuple_paths: dict[tuple[str, ...], int]
property of: GraphEntrySpec

The entry whose derivative is taken.

property out: GraphEntrySpec

The specification of the derivative output.

property wrt: GraphEntrySpec

The entry with respect to which the derivative is taken.

tensorial.gcnn.experimental.derivatives.diff(*func_of, wrt, of=None, out=None, scale=1.0, at=None, return_graph=False)

Constructs a JAX-compatible evaluator for computing single or multiple derivatives of a scalar function (e.g., energy) defined over a Graph.

This function acts as a factory: depending on the wrt argument, it creates either a SingleDerivative or a MultiDerivative object.

The key feature of this function is the use of string-based tensor index notation (e.g., 'nodes.positions:Iγ') for specifying the differentiation targets and the output shape.

Parameters:
  • *func_of –

    Either (func), where func is the energy function, or (func, of), where of is the scalar entry (e.g., 'globals.energy') to differentiate.
      - func (Callable): The function (Graph -> Graph) whose output is differentiated.
      - of (GraphEntrySpecLike, optional): The scalar entry within the output graph of func to differentiate. Defaults to the sole scalar output if omitted.

  • wrt (GraphEntrySpecLike | Sequence[GraphEntrySpecLike]) –

    The input entries of the graph with respect to which the derivative is taken, given as a string or a sequence of strings in index notation. Example: 'nodes.positions:Iγ' means differentiate with respect to the $\gamma$ component of the position of node $I$.

  • out (GraphEntrySpecLike, optional) –

    The index notation for the desired output tensor shape. This string defines the contraction of the indices from wrt. Example: if wrt=['field:α', 'field:β', 'positions:Iγ'], then out=':Iγαβ' specifies a rank-4 tensor output with indices $I, \gamma, \alpha, \beta$. If omitted, the indices are concatenated in the order they appear in wrt.

  • scale (float) – A scalar factor by which the final derivative is multiplied. Defaults to 1.0.

  • at (dict | None) – A dictionary mapping GraphEntrySpecLike strings (without indices) to jax.numpy arrays that give the values at which to evaluate the derivative for those entries. These entries are held constant at the given values. Example: {'globals.electric_field': jnp.zeros(3)}.

  • return_graph (bool) – If True, the derivative tensor is packaged into a new Graph object under the name specified by the out argument. If False, the function returns the raw derivative tensor. Defaults to False.

Returns:

A callable object that takes a Graph and returns the computed derivative tensor (or a Graph containing it).

Return type:

Evaluator

Raises:

TypeError – If the arguments do not conform to the expected types.

Note

The index notation used in ‘wrt’ and ‘out’ must adhere to the library’s conventions for Graph entry keys and indices. For multi-derivatives, the number of unique indices in ‘out’ must match the number of indices in ‘wrt’.
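To make the out convention concrete: with wrt=['field:α', 'field:β', 'positions:Iγ'] the default (concatenated) output order is αβIγ, and out=':Iγαβ' asks for the axes to be reordered. The sketch below computes that reordering as a transpose permutation and enforces a stricter form of the rule in the Note. output_axes is a hypothetical helper that models reordering only, not index contraction, and it is not the library's implementation.

```python
from typing import Optional, Sequence


def output_axes(wrt_indices: Sequence[str], out: Optional[str] = None) -> tuple[int, ...]:
    """Axis permutation that reorders the default derivative layout into ``out``.

    Hypothetical helper. Assumptions: each index label is a single character,
    labels are distinct, and ``out`` only reorders labels (no contraction).
    """
    # Default layout: the index strings of ``wrt``, concatenated in order.
    default = "".join(wrt_indices)
    if len(set(default)) != len(default):
        raise ValueError("this sketch assumes distinct index labels")
    if out is None:
        return tuple(range(len(default)))
    # Keep only the part after ':' (':Iγαβ' has an empty key path).
    out_indices = out.split(":", 1)[-1]
    # Stricter form of the Note's rule: ``out`` must rearrange the wrt labels.
    if sorted(out_indices) != sorted(default):
        raise ValueError(f"out indices {out_indices!r} do not rearrange {default!r}")
    return tuple(default.index(label) for label in out_indices)


axes = output_axes(["α", "β", "Iγ"], ":Iγαβ")
print(axes)  # (2, 3, 0, 1)
```

Passing this tuple to numpy.transpose or jax.numpy.transpose would turn an array laid out as αβIγ into one laid out as Iγαβ.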


tensorial.gcnn.experimental.utils module

class tensorial.gcnn.experimental.utils.GraphMutator(graph)

Bases: object

Parameters:

graph (GraphsTuple)

delete(path)
Parameters:

path (str | tuple)

Return type:

GraphMutator

get()
Return type:

GraphsTuple

set(path, value)
Parameters:

path (str | tuple)

Return type:

GraphMutator

update(path, updates)
Parameters:
  • path (str | tuple)

  • updates (dict)

Return type:

GraphMutator

tensorial.gcnn.experimental.utils.update_graph(graph)
Parameters:

graph (GraphsTuple)

Return type:

GraphMutator
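update_graph returns a GraphMutator, and set, update and delete each return the mutator, which points to a chainable interface finished by get(). The sketch below imitates that pattern on a plain nested dict instead of a GraphsTuple. DictMutator and update_dict are hypothetical; the dotted-path convention and the copy made in the constructor are assumptions of this sketch, not documented behaviour of GraphMutator.

```python
import copy
from typing import Any, Union

Path = Union[str, tuple]


def _keys(path: Path) -> tuple:
    """Accept 'a.b.c' or ('a', 'b', 'c'), mirroring the documented str | tuple path type."""
    return tuple(path.split(".")) if isinstance(path, str) else tuple(path)


class DictMutator:
    """Chainable editor for a nested dict (hypothetical analogue of GraphMutator)."""

    def __init__(self, data: dict):
        self._data = copy.deepcopy(data)  # leave the caller's object untouched

    def _parent(self, path: Path, create: bool = True):
        """Walk to the container holding the last key; optionally create missing levels."""
        *head, last = _keys(path)
        node = self._data
        for key in head:
            if key not in node:
                if not create:
                    return None, last
                node[key] = {}
            node = node[key]
        return node, last

    def set(self, path: Path, value: Any) -> "DictMutator":
        node, last = self._parent(path)
        node[last] = value
        return self

    def update(self, path: Path, updates: dict) -> "DictMutator":
        node, last = self._parent(path)
        node.setdefault(last, {}).update(updates)
        return self

    def delete(self, path: Path) -> "DictMutator":
        node, last = self._parent(path, create=False)
        if node is not None:
            node.pop(last, None)
        return self

    def get(self) -> dict:
        return self._data


def update_dict(data: dict) -> DictMutator:
    """Factory in the style of update_graph."""
    return DictMutator(data)


graph = {"nodes": {"positions": [0.0, 1.0]}, "globals": {"energy": -1.2}}
new = (
    update_dict(graph)
    .set("globals.energy", -1.5)
    .update(("nodes",), {"charges": [0.1, -0.1]})
    .delete("nodes.positions")
    .get()
)
print(new)  # {'nodes': {'charges': [0.1, -0.1]}, 'globals': {'energy': -1.5}}
```

Returning self from every editing method is what makes the fluent chain possible; the deep copy is a simplification chosen for this sketch.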

Module contents

The package namespace tensorial.gcnn.experimental also exposes GraphEntrySpec, MultiDerivative, SingleDerivative and diff from tensorial.gcnn.experimental.derivatives, and GraphMutator and update_graph from tensorial.gcnn.experimental.utils. They are documented identically to the submodule entries above.