tensorial.training package
Module contents
- class tensorial.training.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))
  Bases: Module[GraphsTuple, GraphsTuple]
  Tensorial REAX module.
  - Parameters:
    - model (Module)
    - loss_fn (Callable[[GraphsTuple, GraphsTuple], Array])
    - optimizer (GradientTransformation | Callable[[], GradientTransformation])
    - scheduler (Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] | None)
    - metrics (dict[str, Metric | str] | None)
    - output (Sequence[str] | None)
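The constructor signature can be read as a set of callable contracts. The pure-Python sketch below illustrates them under stated assumptions. Graphs is a hypothetical stand-in for jraph's GraphsTuple. Sgd is a hypothetical stand-in for an optimizer; the listed types appear to correspond to optax's GradientTransformation and schedule types, but the source does not confirm this. resolve_optimizer and halving_schedule are illustrative helpers, not part of tensorial. The sketch shows three things: a loss_fn of the documented (predictions, targets) -> scalar shape, one way to normalize an optimizer given either directly or as a zero-argument factory, and a scheduler that maps a step count to a value.

```python
from typing import Callable, NamedTuple, Union


class Graphs(NamedTuple):
    """Hypothetical stand-in for jraph.GraphsTuple that keeps only per-graph values."""

    globals: list[float]


def mse_loss(predictions: Graphs, targets: Graphs) -> float:
    """A loss_fn with the (predictions, targets) -> scalar shape listed above."""
    diffs = [(p - t) ** 2 for p, t in zip(predictions.globals, targets.globals)]
    return sum(diffs) / len(diffs)


class Sgd:
    """Hypothetical optimizer object standing in for a GradientTransformation."""

    def __init__(self, learning_rate: float = 0.1):
        self.learning_rate = learning_rate


# The optimizer argument may be an optimizer or a zero-argument factory.
OptimizerSpec = Union[Sgd, Callable[[], Sgd]]


def resolve_optimizer(spec: OptimizerSpec) -> Sgd:
    """Return the optimizer directly, or call the factory to build one."""
    return spec if isinstance(spec, Sgd) else spec()


def halving_schedule(step: int) -> float:
    """A scheduler: maps a step count to a value, halving it every 100 steps."""
    return 0.1 * 0.5 ** (step // 100)
```

Accepting a factory lets the caller defer building the optimizer until it is needed. In real use these objects would be JAX/optax values rather than plain Python floats.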
- static calculate_metrics(predictions, targets, metrics)
  - Parameters:
    - predictions (GraphsTuple)
    - targets (GraphsTuple)
    - metrics (dict[str, Metric | str])
  - Return type:
    dict[str, Metric]
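Each entry in the calculate_metrics signature may be either a metric object or a string. One plausible reading is that string entries are looked up by name before being evaluated on the batch. The sketch below assumes that behaviour. Metric, REGISTRY, and the create/compute methods are hypothetical and do not reproduce the reax Metric API.

```python
from typing import Callable, NamedTuple, Union


class Graphs(NamedTuple):
    """Hypothetical stand-in for GraphsTuple holding per-graph values."""

    globals: list[float]


class Metric:
    """Hypothetical metric holding an accumulated (total, count) state."""

    def __init__(self, fn: Callable[[float, float], float], total: float = 0.0, count: int = 0):
        self.fn, self.total, self.count = fn, total, count

    def create(self, predictions: Graphs, targets: Graphs) -> "Metric":
        # Build a new metric whose state reflects this batch only.
        values = [self.fn(p, t) for p, t in zip(predictions.globals, targets.globals)]
        return Metric(self.fn, sum(values), len(values))

    def compute(self) -> float:
        return self.total / self.count


# Hypothetical registry that lets users refer to metrics by name.
REGISTRY: dict[str, Metric] = {
    "mae": Metric(lambda p, t: abs(p - t)),
    "mse": Metric(lambda p, t: (p - t) ** 2),
}


def calculate_metrics(
    predictions: Graphs, targets: Graphs, metrics: dict[str, Union[Metric, str]]
) -> dict[str, Metric]:
    results = {}
    for label, metric in metrics.items():
        if isinstance(metric, str):
            metric = REGISTRY[metric]  # resolve a name such as "mae"
        results[label] = metric.create(predictions, targets)
    return results
```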
- configure_model(_stage, batch, /)
  Called at the beginning of each stage.
  A chance to configure the model. This method should be idempotent, i.e. calling it a second time should do nothing.
  - Parameters:
    - _stage (Stage)
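The idempotency requirement usually means guarding one-off work such as parameter initialization, so that running the hook at the start of every stage is harmless. The following is a minimal sketch. SketchModule and Graphs are hypothetical, and sizing the parameters from the first batch is an assumption made for illustration.

```python
from typing import NamedTuple, Optional


class Graphs(NamedTuple):
    """Hypothetical stand-in for GraphsTuple holding per-graph values."""

    globals: list[float]


class SketchModule:
    """Hypothetical module showing an idempotent configure_model hook."""

    def __init__(self):
        self.params: Optional[dict] = None
        self.init_calls = 0  # counts how often initialization actually ran

    def configure_model(self, _stage, batch: Graphs, /) -> None:
        # Initialize only on the first call; every later call is a no-op.
        if self.params is not None:
            return
        self.init_calls += 1
        # Size the parameters from the first batch seen (illustrative assumption).
        self.params = {"weights": [0.0] * len(batch.globals)}
```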
- property debug: bool
- on_before_optimizer_step(_optimizer, grad, /)
  Called before optimizer.step().
  If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.
  If clipping gradients, the gradients will not have been clipped yet.
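The ordering described above can be illustrated with plain lists of floats: the gradients are accumulated, the hook is called, the gradients are clipped, and the optimizer steps. This is a hypothetical sketch. run_update and clip_by_global_norm are illustrative names, averaging over the accumulated batches is an assumption, and real gradients would be PyTrees rather than lists.

```python
from typing import Callable


def clip_by_global_norm(grad: list[float], max_norm: float) -> list[float]:
    """Scale the gradient down so its Euclidean norm is at most max_norm."""
    norm = sum(g * g for g in grad) ** 0.5
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def run_update(
    batch_grads: list[list[float]],
    hook: Callable[[list[float]], None],
    max_norm: float,
) -> list[float]:
    """Hypothetical update: average accumulated gradients, call the hook, then clip."""
    n = len(batch_grads)
    accumulated = [sum(gs) / n for gs in zip(*batch_grads)]
    hook(accumulated)  # the hook sees accumulated but still unclipped gradients
    # An optimizer step would consume the clipped gradients returned here.
    return clip_by_global_norm(accumulated, max_norm)
```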
- predict_step(batch, _batch_idx, /)
  Make a model prediction and return the result.
  - Parameters:
    - batch (GraphsTuple)
    - _batch_idx (int)
  - Return type:
    GraphsTuple
- static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())
  Calculate the loss and, optionally, metrics.
  - Parameters:
    - params (PyTree)
    - inputs (GraphsTuple)
    - _targets (GraphsTuple)
    - model (Callable[[PyTree, GraphsTuple], GraphsTuple])
    - loss_fn (Callable)
    - metrics (MetricCollection | None)
    - output (tuple[str, ...])
  - Return type:
    tuple[Array, dict]
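The following is a minimal pure-Python sketch of a function with the same shape as step. It applies model to params and inputs, evaluates loss_fn, and returns a (loss, dict) pair. The sketch makes three assumptions:

- metrics is simplified to a dict of plain functions instead of a MetricCollection.
- The auxiliary dict layout ("metrics" plus the keys named in output) is invented for illustration.
- The targets are passed to loss_fn explicitly. The leading underscore in the documented _targets parameter may signal that the real implementation does not use it.

A (value, aux) pair of this form is what jax.value_and_grad(..., has_aux=True) expects, which may explain the return type, but the source does not say so.

```python
from typing import Callable, NamedTuple, Optional


class Graphs(NamedTuple):
    """Hypothetical stand-in for GraphsTuple holding per-graph values."""

    globals: list[float]


def step(
    params: dict,
    inputs: Graphs,
    targets: Graphs,
    model: Callable[[dict, Graphs], Graphs],
    loss_fn: Callable[[Graphs, Graphs], float],
    metrics: Optional[dict] = None,
    output: tuple = (),
) -> tuple[float, dict]:
    predictions = model(params, inputs)
    loss = loss_fn(predictions, targets)
    aux: dict = {}
    if metrics:
        aux["metrics"] = {name: fn(predictions, targets) for name, fn in metrics.items()}
    # Attach only the extras requested via output, e.g. ("predictions",).
    available = {"predictions": predictions, "targets": targets}
    for key in output:
        aux[key] = available[key]
    return loss, aux


# Hypothetical helpers used to exercise step.
def scale_model(params: dict, graphs: Graphs) -> Graphs:
    return Graphs([params["w"] * g for g in graphs.globals])


def sse_loss(predictions: Graphs, targets: Graphs) -> float:
    return sum((p - t) ** 2 for p, t in zip(predictions.globals, targets.globals))


def max_abs_error(predictions: Graphs, targets: Graphs) -> float:
    return max(abs(p - t) for p, t in zip(predictions.globals, targets.globals))
```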
- test_step(batch, _batch_idx, /)
  Test step.
  - Parameters:
    - batch (tuple[GraphsTuple, GraphsTuple])
    - _batch_idx (int)
  - Return type:
    StepOutput | None
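Since test_step receives an (inputs, targets) tuple and may return None, a test step can unpack the pair, evaluate the model, and record results instead of returning them. This is a hypothetical sketch. SketchModule, its scaling model, and the logged dict are illustrative stand-ins, not reax's logging machinery.

```python
from typing import NamedTuple, Optional


class Graphs(NamedTuple):
    """Hypothetical stand-in for GraphsTuple holding per-graph values."""

    globals: list[float]


class SketchModule:
    """Hypothetical module illustrating a test_step over (inputs, targets) batches."""

    def __init__(self, scale: float):
        self.scale = scale
        self.logged: dict[str, float] = {}

    def model(self, inputs: Graphs) -> Graphs:
        return Graphs([self.scale * x for x in inputs.globals])

    def test_step(self, batch: tuple[Graphs, Graphs], _batch_idx: int, /) -> Optional[dict]:
        inputs, targets = batch  # the batch is an (inputs, targets) pair
        predictions = self.model(inputs)
        errors = [abs(p - t) for p, t in zip(predictions.globals, targets.globals)]
        self.logged["test_mae"] = sum(errors) / len(errors)
        return None  # results are recorded in self.logged rather than returned
```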