tensorial.training package
Module contents
- class tensorial.training.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))
Bases: Module[InputT, OutputT_co]
Tensorial REAX module. InputT and OutputT_co are type variables; OutputT_co is covariant.
- Parameters:
  - model (Module)
  - loss_fn (Callable[[OutputT_co, InputT], Array])
  - optimizer (GradientTransformation | Callable[[], GradientTransformation])
  - scheduler (Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]] | None)
  - metrics (dict[str, Metric | str] | None)
  - output (Sequence[str] | None)
Init function.
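The constructor's type hints describe three callable contracts. The plain-Python sketch below illustrates them under assumptions this page does not confirm: GradientTransformation is a hypothetical stand-in for an optax-style init/update pair, loss_fn receives the predictions and the input batch (as the Callable[[OutputT_co, InputT], Array] hint suggests) and reads its targets from that batch, and scheduler maps a step count to a value in the manner of optax schedules. make_sgd, resolve_optimizer, mse_loss and exponential_decay are illustrative names, not part of tensorial.

```python
from typing import Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Hypothetical stand-in for an optax-style (init, update) pair."""
    init: Callable
    update: Callable


def make_sgd(learning_rate: float) -> GradientTransformation:
    """Hypothetical plain SGD that scales each gradient by -learning_rate."""

    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def update(grads, state, params=None):
        return [-learning_rate * g for g in grads], state

    return GradientTransformation(init, update)


def resolve_optimizer(optimizer):
    """Accept either a GradientTransformation or a zero-argument factory."""
    if isinstance(optimizer, GradientTransformation):
        return optimizer
    return optimizer()  # a factory: build the transformation now


def mse_loss(predictions, inputs):
    """Hypothetical loss_fn(predictions, inputs): targets are read from the batch."""
    targets = inputs["targets"]
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def exponential_decay(step, init_value=1e-2, decay_rate=0.5, decay_steps=100):
    """Hypothetical scheduler: map a step count to a learning-rate value."""
    return init_value * decay_rate ** (step / decay_steps)
```

Passing a factory such as lambda: make_sgd(0.1) instead of an instance lets the optimizer be built lazily, which is presumably why both forms are accepted.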
- static calculate_metrics(predictions, targets, metrics)
- Parameters:
  - predictions (OutputT_co)
  - targets (OutputT_co)
  - metrics (dict[str, Metric | str])
- Return type:
  dict[str, Metric]
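calculate_metrics accepts each metric either as an object or as a string, and returns Metric objects rather than plain numbers. The sketch below shows one way that contract could work, assuming string entries are looked up in a name registry and that each metric stores accumulable state (a running total and count). METRIC_REGISTRY and MeanAbsoluteError are hypothetical; the real Metric API may differ.

```python
from dataclasses import dataclass


@dataclass
class MeanAbsoluteError:
    """Hypothetical metric that stores a running total so batches can be merged."""
    total: float = 0.0
    count: int = 0

    @classmethod
    def create(cls, predictions, targets):
        errors = [abs(p - t) for p, t in zip(predictions, targets)]
        return cls(total=sum(errors), count=len(errors))

    def compute(self):
        return self.total / self.count


# Hypothetical registry used to resolve string entries in ``metrics``.
METRIC_REGISTRY = {"mae": MeanAbsoluteError}


def calculate_metrics(predictions, targets, metrics):
    """Return one metric object per key, resolving string names first."""
    results = {}
    for key, metric in metrics.items():
        metric_cls = METRIC_REGISTRY[metric] if isinstance(metric, str) else metric
        results[key] = metric_cls.create(predictions, targets)
    return results
```

Returning state instead of a final number lets per-batch results be combined before the final value is computed.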
- configure_model(_stage, batch, /)
Called at the beginning of each stage.
Gives the module a chance to configure the model. This method should be idempotent: calling it a second time should have no further effect.
- Parameters:
  - _stage (Stage)
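Because configure_model receives a batch and may run at the start of every stage, one common pattern is to initialise parameters lazily from the first batch and skip the work on later calls. The sketch below assumes that pattern; LazyModule, its params layout and the init_calls counter are hypothetical and only illustrate the idempotence requirement.

```python
class LazyModule:
    """Hypothetical module that initialises parameters from the first batch it sees."""

    def __init__(self):
        self.params = None
        self.init_calls = 0  # counts real initialisations, for illustration

    def configure_model(self, _stage, batch, /):
        # Idempotent: once parameters exist, every later call is a no-op.
        if self.params is not None:
            return
        self.init_calls += 1
        n_features = len(batch[0])  # infer the input width from the first sample
        self.params = {"weights": [0.0] * n_features, "bias": 0.0}
```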
- property debug: bool
- on_before_optimizer_step(_optimizer, grad, /)
Called before optimizer.step().
If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.
If clipping gradients, the gradients will not have been clipped yet.
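Since grad arrives already accumulated but not yet clipped, this hook is a natural place to log the raw gradient norm. The sketch assumes gradients form a nested dict/list tree of Python floats; GradNormLogger and global_norm are hypothetical names, and a JAX implementation would more likely call optax.global_norm on array leaves.

```python
import math


def global_norm(tree):
    """L2 norm over every leaf of a nested dict/list/tuple of floats."""
    if isinstance(tree, dict):
        children = tree.values()
    elif isinstance(tree, (list, tuple)):
        children = tree
    else:
        return abs(tree)  # a scalar leaf
    return math.sqrt(sum(global_norm(child) ** 2 for child in children))


class GradNormLogger:
    """Hypothetical hook implementation that records the pre-clipping gradient norm."""

    def __init__(self):
        self.logged = []

    def on_before_optimizer_step(self, _optimizer, grad, /):
        # grad is already accumulated here, but clipping has not been applied yet.
        self.logged.append(global_norm(grad))
```

If the logged norm regularly exceeds the clipping threshold, clipping is active on most steps.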
- predict_step(batch, _batch_idx, /)
Make a model prediction and return the result.
- Parameters:
  - batch (InputT)
  - _batch_idx (int)
- Return type:
  OutputT_co
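Unlike test_step, predict_step is typed to receive only inputs (InputT rather than an (inputs, targets) pair), so a prediction step presumably needs no targets or loss. A minimal sketch, assuming the model is a function of (params, inputs); LinearPredictor is hypothetical.

```python
class LinearPredictor:
    """Hypothetical module whose model computes w * x + b for each input."""

    def __init__(self, w, b):
        self.params = {"w": w, "b": b}

    @staticmethod
    def model(params, inputs):
        return [params["w"] * x + params["b"] for x in inputs]

    def predict_step(self, batch, _batch_idx, /):
        # batch holds inputs only; no targets or loss are involved.
        return self.model(self.params, batch)
```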
- static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())
Calculate the loss and, optionally, metrics.
- Parameters:
  - params (PyTree)
  - inputs (InputT)
  - _targets (OutputT_co)
  - model (Callable[[PyTree, InputT], OutputT_co])
  - loss_fn (Callable[[OutputT_co, InputT], Array])
  - metrics (MetricCollection | None)
  - output (tuple[str, ...])
- Return type:
  tuple[Array, dict]
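The (loss, dict) return type of step matches the (value, aux) pair that jax.value_and_grad(..., has_aux=True) expects, which suggests step is differentiated with respect to params. The plain-Python sketch below mirrors only the forward computation under stated assumptions: metrics is simplified to a dict of callables instead of a MetricCollection, and the aux keys "metrics" and "predictions" are hypothetical.

```python
def step(params, inputs, _targets, model, loss_fn, metrics=None, output=()):
    """Hypothetical forward pass: return (loss, aux) for one batch."""
    predictions = model(params, inputs)
    loss = loss_fn(predictions, inputs)  # loss sees predictions and inputs

    aux = {}
    if metrics is not None:
        # Simplification: metrics maps names to fn(predictions, inputs) -> value.
        aux["metrics"] = {name: fn(predictions, inputs) for name, fn in metrics.items()}
    if "predictions" in output:
        aux["predictions"] = predictions
    return loss, aux


# Example: a one-weight linear model with targets stored in the input batch.
def model(params, inputs):
    return [params["w"] * x for x in inputs["x"]]


def loss_fn(predictions, inputs):
    errors = [(p - y) ** 2 for p, y in zip(predictions, inputs["y"])]
    return sum(errors) / len(errors)


batch = {"x": [1.0, 2.0], "y": [2.0, 5.0]}
loss, aux = step({"w": 2.0}, batch, None, model, loss_fn, output=("predictions",))
```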
- test_step(batch, _batch_idx, /)
Run one test step on an (inputs, targets) batch.
- Parameters:
  - batch (tuple[InputT, OutputT_co])
  - _batch_idx (int)
- Return type:
  StepOutput | None
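test_step receives an (inputs, targets) pair. The sketch below unpacks it and returns only the entries named in output, echoing the class default output=('predictions', 'targets'). It assumes a plain dict is an acceptable stand-in for StepOutput and that None signals that nothing was requested; EvalSketch is hypothetical.

```python
from typing import Any, Dict, Optional


class EvalSketch:
    """Hypothetical module showing how a test_step batch might be consumed."""

    def __init__(self, w, output=("predictions", "targets")):
        self.params = {"w": w}
        self.output = output

    @staticmethod
    def model(params, inputs):
        return [params["w"] * x for x in inputs]

    def test_step(self, batch, _batch_idx, /) -> Optional[Dict[str, Any]]:
        inputs, targets = batch  # batch is an (inputs, targets) pair
        predictions = self.model(self.params, inputs)
        available = {"predictions": predictions, "targets": targets}
        # Only entries named in ``output`` are returned; None if none were requested.
        result = {key: available[key] for key in self.output if key in available}
        return result or None
```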