tensorial.training package

Module contents

class tensorial.training.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))

Bases: Module[GraphsTuple, GraphsTuple]

REAX training module for Tensorial models operating on GraphsTuple batches.

Parameters:
  • model (Module)

  • loss_fn (Callable[[GraphsTuple, GraphsTuple], Array])

  • optimizer (GradientTransformation | Callable[[], GradientTransformation])

  • scheduler (Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] | None)

  • metrics (dict[str, Metric | str] | None)

  • jit (bool)

  • donate_graph (bool)

  • output (Sequence[str] | None)
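The loss_fn argument takes the predicted graph and the target graph and returns a scalar Array. The sketch below is a hypothetical mean-squared-error loss on node features. It uses a minimal NamedTuple stand-in for jraph.GraphsTuple (only the nodes field) so that it runs without jraph, and the assumption that predictions and targets are stored as nodes arrays of the same shape is illustrative, not taken from Tensorial.

```python
from typing import NamedTuple

import jax.numpy as jnp


class Graph(NamedTuple):
    """Hypothetical stand-in for jraph.GraphsTuple with only a `nodes` field."""

    nodes: jnp.ndarray


def mse_node_loss(predictions: Graph, targets: Graph) -> jnp.ndarray:
    """Mean squared error between predicted and target node features (illustrative)."""
    return jnp.mean((predictions.nodes - targets.nodes) ** 2)


pred = Graph(nodes=jnp.array([[1.0], [2.0], [3.0]]))
target = Graph(nodes=jnp.array([[1.0], [2.0], [5.0]]))
loss = mse_node_loss(pred, target)  # (0 + 0 + 4) / 3
```

A function with this shape could then be passed as ReaxModule(model, loss_fn=mse_node_loss, optimizer=...), matching the documented Callable[[GraphsTuple, GraphsTuple], Array] signature.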


static calculate_metrics(predictions, targets, metrics)

Compute the requested metrics from the predictions and targets.

Parameters:
  • predictions (GraphsTuple)

  • targets (GraphsTuple)

  • metrics (dict[str, Metric | str])

Return type:

dict[str, Metric]
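The metrics argument maps a name to either a Metric or a string. One plausible reading is that a string names a built-in metric to be looked up. The sketch below models that with a hypothetical METRIC_REGISTRY and plain Python functions in place of REAX Metric objects, and it returns computed values rather than Metric instances, so it illustrates only the name-resolution step.

```python
from typing import Callable, Sequence, Union

MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def mean_abs_error(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def max_abs_error(pred: Sequence[float], target: Sequence[float]) -> float:
    return max(abs(p - t) for p, t in zip(pred, target))


# Hypothetical name -> metric lookup; not the actual Tensorial/REAX registry.
METRIC_REGISTRY: dict[str, MetricFn] = {"mae": mean_abs_error, "max_error": max_abs_error}


def calculate_metrics_sketch(
    pred: Sequence[float],
    target: Sequence[float],
    metrics: dict[str, Union[MetricFn, str]],
) -> dict[str, float]:
    results = {}
    for name, metric in metrics.items():
        # Strings are resolved through the registry; callables are used as-is.
        fn = METRIC_REGISTRY[metric] if isinstance(metric, str) else metric
        results[name] = fn(pred, target)
    return results


values = calculate_metrics_sketch(
    [1.0, 2.0, 3.0], [1.0, 2.5, 5.0], {"mae": "mae", "worst": max_abs_error}
)
```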

configure_model(_stage, batch, /)

Called at the beginning of each stage.

A chance to configure the model. This method should be idempotent, i.e. calling it a second time should do nothing.

Parameters:
  • _stage (Stage)
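Because configure_model may run at the start of every stage, an override typically guards its setup so that repeated calls leave the existing state untouched. The class and attribute names below (IdempotentSetup, params, init_fn) are hypothetical, and _stage is passed as a plain string in place of REAX's Stage; the sketch shows only the idempotency pattern, not ReaxModule's own implementation.

```python
from typing import Any, Callable, Optional


class IdempotentSetup:
    """Hypothetical module showing an idempotent configure_model."""

    def __init__(self, init_fn: Callable[[Any], dict]):
        self._init_fn = init_fn
        self.params: Optional[dict] = None
        self.init_calls = 0

    def configure_model(self, _stage: str, batch: Any, /) -> None:
        if self.params is not None:
            return  # Already configured: a second call does nothing.
        self.params = self._init_fn(batch)
        self.init_calls += 1


m = IdempotentSetup(init_fn=lambda batch: {"w": len(batch)})
m.configure_model("fit", [1, 2, 3])
m.configure_model("validate", [1, 2])  # no-op: params keep their first value
```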

configure_optimizers()

Create the optimizer(s) to use during training.
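The optimizer argument accepts either a GradientTransformation or a zero-argument callable that builds one; a factory defers construction until the optimizers are created. The sketch below assumes the two cases are told apart with callable(), which works because optax's GradientTransformation is a NamedTuple and therefore not itself callable. The Transform stand-in, sgd, and resolve_optimizer are hypothetical helpers, not part of Tensorial or optax.

```python
from typing import Callable, NamedTuple, Union


class Transform(NamedTuple):
    """Stand-in for optax.GradientTransformation (a NamedTuple of init/update)."""

    init: Callable
    update: Callable


def sgd(lr: float) -> Transform:
    # Minimal plain-SGD transformation over a list of floats (illustrative only).
    return Transform(
        init=lambda params: (),
        update=lambda grads, state, params=None: ([-lr * g for g in grads], state),
    )


def resolve_optimizer(optimizer: Union[Transform, Callable[[], Transform]]) -> Transform:
    # NamedTuple instances are not callable, so only a factory takes this branch.
    return optimizer() if callable(optimizer) else optimizer


direct = resolve_optimizer(sgd(0.1))
deferred = resolve_optimizer(lambda: sgd(0.1))
updates, _ = deferred.update([1.0, -2.0], deferred.init([0.0, 0.0]))
```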

property debug: bool

on_before_optimizer_step(_optimizer, grad, /)

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See accumulate_grad_batches.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:
  • _optimizer (Optimizer) – The current optimizer being used.

  • grad (dict[str, Any]) – The gradients dictionary from JAX.
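Since this hook receives the gradients before any clipping, it is a natural place to inspect them, for example to track the global gradient norm. The helper below is a hypothetical function an override might call; it uses only jax.tree_util.tree_leaves and jax.numpy and works on a nested dictionary of arrays like the grad argument described above. optax also provides optax.global_norm for the same computation.

```python
import jax
import jax.numpy as jnp


def global_grad_norm(grad: dict) -> jnp.ndarray:
    """L2 norm over every leaf array in a (possibly nested) gradient dict."""
    leaves = jax.tree_util.tree_leaves(grad)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))


grads = {"dense": {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}}
norm = global_grad_norm(grads)  # sqrt(9 + 0 + 16) = 5
```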

predict_step(batch, _batch_idx, /)

Make a model prediction and return the result.

Parameters:
  • batch (GraphsTuple)

  • _batch_idx (int)

Return type:

GraphsTuple

static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())

Calculate the loss and, optionally, metrics.

Parameters:
  • params (PyTree)

  • inputs (GraphsTuple)

  • _targets (GraphsTuple)

  • model (Callable[[PyTree, GraphsTuple], GraphsTuple])

  • loss_fn (Callable)

  • metrics (MetricCollection | None)

  • output (tuple[str, ...])

Return type:

tuple[Array, dict]
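step returns the loss together with a dictionary of auxiliary outputs, which is the (value, aux) shape that jax.value_and_grad(..., has_aux=True) expects. The sketch below is a simplified, hypothetical re-implementation of that contract for a linear model on a stand-in graph type. It omits the metrics argument and passes the targets to loss_fn for illustration, although the leading underscore on _targets in the real signature suggests the actual implementation may not use that argument.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    """Hypothetical stand-in for jraph.GraphsTuple with only a `nodes` field."""

    nodes: jnp.ndarray


def model(params: dict, graph: Graph) -> Graph:
    # Linear map on node features; returns a new graph with predicted nodes.
    return graph._replace(nodes=graph.nodes @ params["w"])


def loss_fn(predictions: Graph, targets: Graph) -> jnp.ndarray:
    return jnp.mean((predictions.nodes - targets.nodes) ** 2)


def step(params, inputs, targets, model, loss_fn, output=()):
    predictions = model(params, inputs)
    loss = loss_fn(predictions, targets)
    # Only the outputs requested by name are returned as auxiliary data.
    aux = {"predictions": predictions} if "predictions" in output else {}
    return loss, aux


params = {"w": jnp.array([[2.0]])}
inputs = Graph(nodes=jnp.array([[1.0], [2.0]]))
targets = Graph(nodes=jnp.array([[1.0], [2.0]]))

# Differentiate with respect to params (argnums=0); the other arguments pass through.
grad_fn = jax.value_and_grad(step, has_aux=True)
(loss, aux), grads = grad_fn(params, inputs, targets, model, loss_fn, ("predictions",))
```

Here the loss is mean((2 - 1)^2, (4 - 2)^2) = 2.5 and its derivative with respect to w is 5.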

test_step(batch, _batch_idx, /)

Run a single test step on a batch of (inputs, targets) graphs.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput | None

training_step(batch, _batch_idx, /)

Run a single training step on a batch of (inputs, targets) graphs.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput

validation_step(batch, _batch_idx, /)

Run a single validation step on a batch of (inputs, targets) graphs.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput | None