tensorial.reaxkit package#

Subpackages#

Submodules#

tensorial.reaxkit.cli module#

Main CLI command

tensorial.reaxkit.cli.main_cli()[source]#

tensorial.reaxkit.config module#

tensorial.reaxkit.config.load_module(config_path='config.yaml', ckpt_path='params.ckpt', checkpointing=None, return_config=False)[source]#

Load a REAX module from a YAML configuration file, optionally restoring parameters from a checkpoint.

This function uses Hydra to instantiate a module from a config file and optionally loads learned parameters from a checkpoint file using the provided or default checkpointing mechanism. It can also return the full configuration object if needed.

Parameters:
  • config_path (str) – Path to the YAML configuration file specifying the model.

  • ckpt_path (str) – Path to the checkpoint file containing saved parameters.

  • checkpointing (Checkpointing, optional) – A Checkpointing instance to use for loading parameters. If None, the default REAX checkpointing is used.

  • return_config (bool) – If True, also return the loaded configuration object.

Returns:

The instantiated REAX module, optionally accompanied by the loaded configuration.

Return type:

ReaxModule | tuple[ReaxModule, DictConfig]
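Because the return type depends on return_config, calling code may have to handle both shapes. The following is a minimal sketch of one way to normalise the result. The helper split_loaded and the stand-in objects are hypothetical and are not part of tensorial; they only mirror the documented return type.

```python
from typing import Any, Optional


def split_loaded(result: Any) -> tuple[Any, Optional[Any]]:
    """Normalise a load_module-style result to (module, config_or_None).

    Hypothetical helper, not part of tensorial.
    """
    if isinstance(result, tuple):
        module, cfg = result
        return module, cfg
    return result, None


# Stand-ins for the objects load_module would return.
fake_module = object()
fake_cfg = {"model": {"hidden": 16}}

module_only = split_loaded(fake_module)
module_and_cfg = split_loaded((fake_module, fake_cfg))
```

With tensorial installed, the value passed to the helper would come from a call such as load_module('config.yaml', 'params.ckpt', return_config=True), following the signature above.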

tensorial.reaxkit.evaluate module#

tensorial.reaxkit.evaluate.main(cfg)[source]#

Main entry point for evaluation.

Parameters:
  • cfg (DictConfig) – Configuration composed by Hydra.

Return type:

None

tensorial.reaxkit.from_data module#

class tensorial.reaxkit.from_data.FromData(cfg, engine, *, rngs=None, dataloader=None, datamodule=None, dataloader_name='train', ignore_missing=True)[source]#

Bases: Stage

A trainer stage that will populate an OmegaConf dictionary with data statistics calculated from metrics.

Populate a Hydra configuration dictionary using calculated statistics.

Parameters:
  • cfg (DictConfig) – the configuration dictionary

  • engine (Engine) – the trainer strategy

  • rngs (Rngs | None) – the random number generator

  • dataloader (DataLoader | None) – the dataloader to use

  • datamodule (DataModule | None) – if no dataloader is specified, a data module can be used instead

  • dataloader_name (str | None) – the name of the datamodule's dataloader to use

  • ignore_missing (bool) – if True, any data that is needed to calculate a metric but is missing will be ignored, and that metric will not be calculated
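The idea of filling a configuration from data statistics, including the ignore_missing behaviour, can be illustrated with a small sketch. Everything here is an assumption made for illustration: the populate function, the METRICS table, the dotted-key scheme, and the use of a plain dict in place of an OmegaConf DictConfig. It is not the FromData implementation.

```python
from statistics import mean, pstdev
from typing import Callable

# Each hypothetical metric names the data field it needs and how to reduce it.
METRICS: dict[str, tuple[str, Callable[[list[float]], float]]] = {
    "energy_mean": ("energy", mean),
    "energy_std": ("energy", pstdev),
    "force_mean": ("forces", mean),
}


def populate(cfg: dict, wanted: dict[str, str], data: dict[str, list[float]],
             ignore_missing: bool = True) -> dict[str, float]:
    """Write each wanted metric into cfg at its dotted key; return what was calculated."""
    calculated = {}
    for dotted_key, metric_name in wanted.items():
        field, reduce_fn = METRICS[metric_name]
        if field not in data:
            if ignore_missing:
                continue  # skip the metric instead of failing
            raise KeyError(f"'{field}' is required by metric '{metric_name}'")
        value = reduce_fn(data[field])
        # Walk (and create) the nested dictionaries named by the dotted key.
        node = cfg
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
        calculated[dotted_key] = value
    return calculated


cfg = {"model": {}}
data = {"energy": [1.0, 2.0, 3.0]}  # note: no "forces" entry
wanted = {"model.shift": "energy_mean", "model.force_shift": "force_mean"}
stats = populate(cfg, wanted, data)
```

In this sketch the force metric is silently skipped because its input is absent, while passing ignore_missing=False would raise instead.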

property calculated: dict[str, Any]#

The dictionary holding the calculated statistics

property dataloader: DataLoader | None#
property dataloaders: DataLoader | None#

Dataloader function.

log(name, value, batch_size=None, prog_bar=False, logger=False, on_step=False, on_epoch=True)[source]#

Log a result while the stage is running.

Parameters:
  • name (str)

  • batch_size (int | None)

  • prog_bar (bool)

  • logger (bool)

Return type:

None

tensorial.reaxkit.keys module#

tensorial.reaxkit.train module#

tensorial.reaxkit.train.main(cfg)[source]#

Main entry point for training.

Parameters:
  • cfg (DictConfig) – Configuration composed by Hydra.

Return type:

float | None

Returns:

Optional[float] with optimized metric value.

tensorial.reaxkit.train.train(cfg)[source]#
Parameters:

cfg (DictConfig | dict)

Module contents#

The REAX toolkit contains classes and functions that help to build a full model training application using tensorial and REAX.

class tensorial.reaxkit.FromData(cfg, engine, *, rngs=None, dataloader=None, datamodule=None, dataloader_name='train', ignore_missing=True)[source]#

Bases: Stage

A trainer stage that will populate an OmegaConf dictionary with data statistics calculated from metrics.

Populate a Hydra configuration dictionary using calculated statistics.

Parameters:
  • cfg (DictConfig) – the configuration dictionary

  • engine (Engine) – the trainer strategy

  • rngs (Rngs | None) – the random number generator

  • dataloader (DataLoader | None) – the dataloader to use

  • datamodule (DataModule | None) – if no dataloader is specified, a data module can be used instead

  • dataloader_name (str | None) – the name of the datamodule's dataloader to use

  • ignore_missing (bool) – if True, any data that is needed to calculate a metric but is missing will be ignored, and that metric will not be calculated

property calculated: dict[str, Any]#

The dictionary holding the calculated statistics

property dataloader: DataLoader | None#
property dataloaders: DataLoader | None#

Dataloader function.

log(name, value, batch_size=None, prog_bar=False, logger=False, on_step=False, on_epoch=True)[source]#

Log a result while the stage is running.

Parameters:
  • name (str)

  • batch_size (int | None)

  • prog_bar (bool)

  • logger (bool)

Return type:

None

class tensorial.reaxkit.GraphParityPlotter(targets, predictions=None, save_dir='plots/', fit_plot_every=100, x_label=None, y_label=None)[source]#

Bases: ParityPlotter

Parameters:
  • targets (Union[str, tuple[str, ...]])

  • predictions (Union[str, tuple[str, ...], None])

  • save_dir (str | Path)

  • fit_plot_every (int)

  • x_label (str | None)

  • y_label (str | None)

class tensorial.reaxkit.MetricsPrinter(log_level=20, log_every=10)[source]#

Bases: ProgressBar

Prints all scalar metrics found in trainer.progress_bar_metrics with dynamic column width based on the title or the value.

Parameters:

log_every (int)
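The "dynamic column width" rule can be sketched as follows. This is an illustrative reading of the description, not the class's actual code: format_row is a hypothetical function, and the two constants simply reuse the values of the MIN_COLUMN_WIDTH and NUMBER_DECIMALS class attributes.

```python
MIN_COLUMN_WIDTH = 5  # same value as the documented class attribute
NUMBER_DECIMALS = 5   # same value as the documented class attribute


def format_row(metrics: dict[str, float]) -> tuple[str, str]:
    """Return (header, values) lines with each column sized to fit title and value."""
    titles, values = [], []
    for name, value in metrics.items():
        text = f"{value:.{NUMBER_DECIMALS}f}"
        # The column is as wide as the longest of: minimum width, title, value.
        width = max(MIN_COLUMN_WIDTH, len(name), len(text))
        titles.append(name.rjust(width))
        values.append(text.rjust(width))
    return " ".join(titles), " ".join(values)


header, row = format_row({"loss": 0.5, "validation_rmse": 12.25})
```

The default log_level=20 matches the numeric value of logging.INFO in Python's standard logging module, which is presumably what the default refers to.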

MIN_COLUMN_WIDTH = 5#
NUMBER_DECIMALS = 5#
disable()[source]#

You should provide a way to disable the progress bar.

Return type:

None

do_log(*args, **kwargs)[source]#
enable()[source]#

You should provide a way to enable the progress bar.

The Trainer will call this in, e.g., pre-training routines like the learning rate finder to temporarily enable and disable the training progress bar.

Return type:

None

init_train_tqdm(stage)[source]#

Override this to customize the tqdm bar for training.

Parameters:

stage (EpochStage)

Return type:

tqdm

property is_disabled: bool#
property is_enabled: bool#
on_stage_end(trainer, stage, /)[source]#

The stage is ending.

Parameters:
  • trainer (Trainer)

  • stage (Stage)

Return type:

None

on_stage_iter_start(_trainer, stage, _step, /)[source]#

A stage is about to start an iteration.

Parameters:
  • _trainer (Trainer)

  • stage (Stage)

  • _step (int)

Return type:

None

on_stage_start(_trainer, stage, /)[source]#

A trainer stage is starting.

Parameters:
  • _trainer (Trainer)

  • stage (Stage)

Return type:

None

class tensorial.reaxkit.ParityPlotter(save_dir='plots/', fit_plot_every=10, x_label='True Values (y)', y_label="Predicted Values (y')")[source]#

Bases: TrainerListener

A TrainerListener that collects true and predicted values to create a parity plot at the end of each stage (Train, Validate, Test, Predict).

Parameters:
  • save_dir (str | Path)

  • fit_plot_every (int)

  • x_label (str)

  • y_label (str)
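The collect-per-batch, act-at-stage-end pattern described above can be sketched without any plotting dependency. PairCollector and its method names are hypothetical stand-ins; where the real listener would draw a parity plot, this sketch returns the RMSE of the collected pairs instead.

```python
import math


class PairCollector:
    """Hypothetical stand-in for a listener that gathers (target, prediction) pairs."""

    def __init__(self) -> None:
        self.targets: list[float] = []
        self.predictions: list[float] = []

    def on_batch_end(self, targets: list[float], predictions: list[float]) -> None:
        # Called once per batch: only accumulate, no plotting yet.
        self.targets.extend(targets)
        self.predictions.extend(predictions)

    def on_stage_end(self) -> float:
        # A real listener would draw the parity plot here; this sketch returns
        # the RMSE of the collected pairs instead, then clears the buffers.
        squared = [(p - t) ** 2 for t, p in zip(self.targets, self.predictions)]
        rmse = math.sqrt(sum(squared) / len(squared))
        self.reset()
        return rmse

    def reset(self) -> None:
        self.targets.clear()
        self.predictions.clear()


collector = PairCollector()
collector.on_batch_end([1.0, 2.0], [1.0, 4.0])
collector.on_batch_end([3.0], [3.0])
rmse = collector.on_stage_end()
```

Clearing the buffers at the end of a stage keeps pairs from one stage (for example training) from leaking into the next (for example validation).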

get_target_predicted(batch, outputs)[source]#
Return type:

tuple[ndarray, ndarray]

on_fit_end(trainer, stage, /)[source]#

Fit has ended, plot the collected training data.

Parameters:
  • trainer (Trainer)

  • stage (Fit)

Return type:

None

on_predict_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)[source]#

The predict stage has just finished processing a batch.

Parameters:
  • _trainer (Trainer)

  • stage (Predict)

  • outputs (Any)

  • batch (Any)

  • _batch_idx (int)

Return type:

None

on_predict_end(trainer, stage, /)[source]#

Predict is ending, plot the collected prediction data.

Parameters:
  • trainer (Trainer)

  • stage (Predict)

Return type:

None

on_test_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)[source]#

The test stage has just finished processing a batch.

Parameters:
  • _trainer (Trainer)

  • stage (Test)

  • outputs (Any)

  • batch (Any)

  • _batch_idx (int)

Return type:

None

on_test_end(trainer, stage, /)[source]#

Test has ended, plot the collected test data.

Parameters:
  • trainer (Trainer)

  • stage (Test)

Return type:

None

on_train_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)[source]#

The training stage has just finished processing a batch.

Parameters:
  • _trainer (Trainer)

  • stage (Train)

  • outputs (Any)

  • batch (Any)

  • _batch_idx (int)

Return type:

None

on_train_end(trainer, stage, /)[source]#

Training is ending, plot the collected training data.

Parameters:
  • trainer (Trainer)

  • stage (Train)

on_validation_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)[source]#

The validation stage has just finished processing a batch.

Parameters:
  • _trainer (Trainer)

  • stage (Validate)

  • outputs (Any)

  • batch (Any)

  • _batch_idx (int)

Return type:

None

on_validation_end(trainer, stage, /)[source]#

Validation has ended, plot the collected validation data.

Parameters:
  • trainer (Trainer)

  • stage (Validate)

Return type:

None

reset()[source]#
class tensorial.reaxkit.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))[source]#

Bases: Module[GraphsTuple, GraphsTuple]

Tensorial REAX module.

Parameters:
  • model (Module)

  • loss_fn (Callable[[GraphsTuple, GraphsTuple], Array])

  • optimizer (GradientTransformation | Callable[[], GradientTransformation])

  • scheduler (Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] | None)

  • metrics (dict[str, Metric | str] | None)

  • output (Sequence[str] | None)

Init function.

static calculate_metrics(predictions, targets, metrics)[source]#
Parameters:
  • predictions (GraphsTuple)

  • targets (GraphsTuple)

  • metrics (dict[str, Metric | str])

Return type:

dict[str, Metric]

configure_model(_stage, batch, /)[source]#

Called at the beginning of each stage.

A chance to configure the model. This method should be idempotent, i.e. calling it a second time should do nothing.

Parameters:

_stage (Stage)
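Idempotence here means that a repeated call leaves the already configured model untouched. A minimal sketch of that guard, using a hypothetical LazyModel class with toy list parameters rather than a real REAX module:

```python
class LazyModel:
    """Hypothetical module whose parameters are created from the first batch."""

    def __init__(self) -> None:
        self.params = None
        self.init_calls = 0

    def configure_model(self, batch: list[float]) -> None:
        if self.params is not None:
            return  # already configured: a repeat call is a no-op
        self.init_calls += 1
        # Size the (toy) parameters from the example batch.
        self.params = [0.0] * len(batch)


model = LazyModel()
model.configure_model([1.0, 2.0, 3.0])
model.configure_model([9.0])  # ignored, parameters already exist
```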

configure_optimizers()[source]#

Create the optimizer(s) to use during training.

property debug: bool#
on_before_optimizer_step(_optimizer, grad, /)[source]#

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:
  • _optimizer (Optimizer) – Current optimizer being used.

  • grad (dict[str, Any]) – The gradients dictionary from JAX.
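A hook like this is a natural place to inspect gradients before they are applied, for instance to monitor their overall magnitude. The sketch below computes a global L2 norm over a nested dictionary; global_norm is a hypothetical helper, and plain Python lists stand in for the JAX arrays a real gradients dictionary would hold.

```python
import math
from typing import Any


def global_norm(grad: dict[str, Any]) -> float:
    """L2 norm over every number in a nested dict of lists (toy gradient tree)."""
    total = 0.0
    stack = [grad]
    while stack:
        node = stack.pop()
        if isinstance(node, dict):
            stack.extend(node.values())
        elif isinstance(node, (list, tuple)):
            stack.extend(node)
        else:
            total += float(node) ** 2
    return math.sqrt(total)


grads = {"dense": {"kernel": [3.0, 0.0], "bias": [4.0]}}
norm = global_norm(grads)
```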

predict_step(batch, _batch_idx, /)[source]#

Make a model prediction and return the result.

Parameters:
  • batch (GraphsTuple)

  • _batch_idx (int)

Return type:

GraphsTuple

static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())[source]#

Calculate the loss and, optionally, metrics.

Parameters:
  • params (PyTree)

  • inputs (GraphsTuple)

  • _targets (GraphsTuple)

  • model (Callable[[PyTree, GraphsTuple], GraphsTuple])

  • loss_fn (Callable)

  • metrics (MetricCollection | None)

  • output (tuple[str, ...])

Return type:

tuple[Array, dict]
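The tuple[Array, dict] return type pairs a scalar loss with a dictionary of extras. A toy, pure-Python sketch of a function with that shape is given below. The handling of metrics and output is an assumption made for illustration (here output selects which of "predictions" and "targets" are copied into the dictionary), and the model, loss and metric functions are hypothetical.

```python
from typing import Callable, Optional


def step(params, inputs, targets, model: Callable, loss_fn: Callable,
         metrics: Optional[dict[str, Callable]] = None,
         output: tuple[str, ...] = ()):
    """Toy version: return (loss, extras), where extras holds optional metrics/outputs."""
    predictions = model(params, inputs)
    loss = loss_fn(predictions, targets)
    extras = {}
    if metrics is not None:
        extras["metrics"] = {name: fn(predictions, targets) for name, fn in metrics.items()}
    available = {"predictions": predictions, "targets": targets}
    for name in output:
        extras[name] = available[name]
    return loss, extras


def linear(params, xs):
    return [params["w"] * x + params["b"] for x in xs]


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def max_abs_error(pred, target):
    return max(abs(p - t) for p, t in zip(pred, target))


loss, extras = step({"w": 2.0, "b": 0.0}, [1.0, 2.0], [2.0, 5.0], linear, mse,
                    metrics={"max_abs": max_abs_error}, output=("predictions",))
```

A (value, auxiliary data) pair is also the shape that jax.value_and_grad accepts when called with has_aux=True, which may be the motivation for this return type.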

test_step(batch, _batch_idx, /)[source]#

Test step.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput | None

training_step(batch, _batch_idx, /)[source]#

Train step.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput

validation_step(batch, _batch_idx, /)[source]#

Validate step.

Parameters:
  • batch (tuple[GraphsTuple, GraphsTuple])

  • _batch_idx (int)

Return type:

StepOutput | None