tensorial.training package

Module contents

class tensorial.training.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))

Bases: Module[InputT, OutputT_co]

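Tensorial REAX module. It bundles a model with its loss function, optimizer, optional scheduler and metrics, and implements the training, validation, test and prediction steps.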

Parameters:
  • model (Module)

  • loss_fn (Callable[[OutputT_co, InputT], Array])

  • optimizer (GradientTransformation | Callable[[], GradientTransformation])

  • scheduler (Callable[[ArrayLike], ArrayLike] | None), where ArrayLike = Union[Array, ndarray, bool, number, bool, int, float, complex], the expansion of jax.typing.ArrayLike

  • metrics (dict[str, Metric | str] | None)

  • output (Sequence[str] | None)
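The callables passed as loss_fn and scheduler can be illustrated with plain Python stand-ins. This is a sketch under assumptions: mse_loss and exponential_decay are hypothetical helpers that operate on lists rather than JAX arrays, the second loss argument is assumed to hold the reference values (the documented type names it InputT), and the scheduler is assumed to map a step count to a learning rate in the style of an Optax schedule such as optax.exponential_decay.

```python
from typing import Callable, Sequence


def mse_loss(predictions: Sequence[float], references: Sequence[float]) -> float:
    """Hypothetical loss_fn stand-in: mean squared error."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    return sum((p - r) ** 2 for p, r in zip(predictions, references)) / len(predictions)


def exponential_decay(init_value: float, decay_rate: float, transition_steps: int) -> Callable[[int], float]:
    """Hypothetical scheduler stand-in: decays smoothly, reaching init_value * decay_rate after transition_steps."""

    def schedule(step: int) -> float:
        return init_value * decay_rate ** (step / transition_steps)

    return schedule


scheduler = exponential_decay(init_value=1e-3, decay_rate=0.5, transition_steps=1000)
print(mse_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # 1.3333333333333333
print(scheduler(0), scheduler(1000))  # 0.001 0.0005
```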

static calculate_metrics(predictions, targets, metrics)

Parameters:
  • predictions (OutputT_co)

  • targets (OutputT_co)

  • metrics (dict[str, Metric | str])

Return type:

dict[str, Metric]
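To show how a metrics mapping that mixes names and metric objects might be resolved, here is a plain-Python sketch. METRIC_REGISTRY, mae and rmse are hypothetical; the real method returns Metric objects keyed by name, whereas this stand-in returns plain floats.

```python
from typing import Callable, Sequence, Union

MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def mae(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(predictions)


def rmse(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Root mean squared error."""
    return (sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)) ** 0.5


# Hypothetical registry used to resolve metrics that are given by name.
METRIC_REGISTRY: dict[str, MetricFn] = {"mae": mae, "rmse": rmse}


def calculate_metrics(
    predictions: Sequence[float],
    targets: Sequence[float],
    metrics: dict[str, Union[MetricFn, str]],
) -> dict[str, float]:
    results = {}
    for label, metric in metrics.items():
        # Strings are looked up in the registry; callables are used directly.
        fn = METRIC_REGISTRY[metric] if isinstance(metric, str) else metric
        results[label] = fn(predictions, targets)
    return results


print(calculate_metrics([1.0, 2.0], [2.0, 4.0], {"energy_mae": "mae", "energy_rmse": rmse}))
```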

configure_model(_stage, batch, /)

Called at the beginning of each stage.

A chance to configure the model. This method should be idempotent, i.e. calling it a second time should do nothing.

Parameters:
  • _stage (Stage)
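One common way to satisfy the idempotence requirement is to guard initialization behind a check for existing parameters. The sketch below is plain Python and hypothetical throughout: HypotheticalModule, its params attribute, the string stage names (standing in for Stage) and the list-of-lists batch layout are assumptions, not REAX API. It infers the input width from the batch on the first call and returns early on every later call.

```python
import random
from typing import Optional


class HypotheticalModule:
    """Illustrates an idempotent configure_model hook (all names are hypothetical)."""

    def __init__(self) -> None:
        self.params: Optional[list[float]] = None
        self.init_calls = 0  # counts how often parameters were actually created

    def configure_model(self, _stage: str, batch: list[list[float]], /) -> None:
        # Guard: once parameters exist, repeated calls leave them untouched.
        if self.params is not None:
            return
        # Infer the input width from the first sample in the batch.
        width = len(batch[0])
        rng = random.Random(0)  # fixed seed so the sketch is reproducible
        self.params = [rng.gauss(0.0, 1.0) for _ in range(width)]
        self.init_calls += 1


module = HypotheticalModule()
module.configure_model("fit", [[1.0, 2.0, 3.0]])
first = list(module.params)
module.configure_model("validate", [[4.0, 5.0, 6.0]])  # no-op: params already exist
```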

configure_optimizers()

Create the optimizer(s) to use during training.

property debug: bool

on_before_optimizer_step(_optimizer, grad, /)

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:
  • _optimizer (Optimizer) – Current optimizer being used.

  • grad (dict[str, Any]) – The gradients dictionary from JAX.
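A typical use of this hook is logging the gradient norm before the optimizer, and any clipping, touches the gradients. The sketch below flattens a nested gradient dictionary in plain Python; in JAX code, optax.global_norm computes the same quantity over a pytree. LoggingHookSketch and its history attribute are hypothetical.

```python
import math
from typing import Any, Iterator


def _leaves(tree: Any) -> Iterator[float]:
    """Yield the numeric leaves of a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from _leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from _leaves(value)
    else:
        yield float(tree)


def global_grad_norm(grad: dict[str, Any]) -> float:
    """L2 norm over every gradient entry, as if all leaves were concatenated."""
    return math.sqrt(sum(g * g for g in _leaves(grad)))


class LoggingHookSketch:
    """Hypothetical module that records the raw gradient norm before each optimizer step."""

    def __init__(self) -> None:
        self.history: list[float] = []

    def on_before_optimizer_step(self, _optimizer: Any, grad: dict[str, Any], /) -> None:
        # Gradients are not clipped yet at this point, so this is the unclipped norm.
        self.history.append(global_grad_norm(grad))


hook = LoggingHookSketch()
hook.on_before_optimizer_step(None, {"dense": {"w": [3.0, 0.0], "b": [4.0]}})
print(hook.history)  # [5.0]
```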

predict_step(batch, _batch_idx, /)

Make a model prediction and return the result.

Parameters:
  • batch (InputT)

  • _batch_idx (int)

Return type:

OutputT_co

static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())

Calculate the loss and, optionally, metrics.

Parameters:
  • params (PyTree)

  • inputs (InputT)

  • _targets (OutputT_co)

  • model (Callable[[PyTree, InputT], OutputT_co])

  • loss_fn (Callable[[OutputT_co, InputT], Array])

  • metrics (MetricCollection | None)

  • output (tuple[str, ...])

Return type:

tuple[Array, dict]
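The return type tuple[Array, dict] has the same (value, auxiliary data) shape that jax.value_and_grad(..., has_aux=True) expects, where only the first element is differentiated. The plain-Python sketch below illustrates that shape under stated assumptions: step_sketch, linear_model, mse_loss and the {"x", "y"} batch layout are hypothetical, the loss receives the inputs as its second argument to mirror the documented loss_fn type, and _targets is merely echoed back when "targets" is requested.

```python
from typing import Any

Batch = dict[str, list[float]]  # hypothetical layout: {"x": inputs, "y": reference values}


def linear_model(params: tuple[float, float], inputs: Batch) -> list[float]:
    # Hypothetical model: y_hat = w * x + b for every sample.
    w, b = params
    return [w * x + b for x in inputs["x"]]


def mse_loss(predictions: list[float], inputs: Batch) -> float:
    # Second argument is the inputs, mirroring the documented loss_fn type.
    ys = inputs["y"]
    return sum((p - y) ** 2 for p, y in zip(predictions, ys)) / len(ys)


def step_sketch(params, inputs, _targets, model, loss_fn, metrics=None, output=()):
    """Return (loss, aux) in the shape jax.value_and_grad(..., has_aux=True) expects."""
    predictions = model(params, inputs)
    loss = loss_fn(predictions, inputs)
    aux: dict[str, Any] = {}
    if metrics:
        aux["metrics"] = {name: fn(predictions, inputs) for name, fn in metrics.items()}
    # Only the requested entries are returned, keeping the auxiliary output small.
    available = {"predictions": predictions, "targets": _targets}
    for key in output:
        aux[key] = available[key]
    return loss, aux


batch = {"x": [0.0, 1.0, 2.0], "y": [1.0, 3.0, 5.0]}
loss, aux = step_sketch((2.0, 1.0), batch, batch["y"], linear_model, mse_loss, output=("predictions",))
print(loss, aux)  # 0.0 {'predictions': [1.0, 3.0, 5.0]}
```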

test_step(batch, _batch_idx, /)

Run a single test step on a batch of (inputs, targets).

Parameters:
  • batch (tuple[InputT, OutputT_co])

  • _batch_idx (int)

Return type:

StepOutput | None

training_step(batch, _batch_idx, /)

Run a single training step on a batch of (inputs, targets).

Parameters:
  • batch (tuple[InputT, OutputT_co])

  • _batch_idx (int)

Return type:

StepOutput

validation_step(batch, _batch_idx, /)

Run a single validation step on a batch of (inputs, targets).

Parameters:
  • batch (tuple[InputT, OutputT_co])

  • _batch_idx (int)

Return type:

StepOutput | None