tensorial.reaxkit package
Subpackages
Submodules
tensorial.reaxkit.cli module
Main CLI command
tensorial.reaxkit.config module
- tensorial.reaxkit.config.load_module(config_path='config.yaml', ckpt_path='params.ckpt', checkpointing=None, return_config=False)
Load a REAX module from a YAML configuration file, optionally restoring parameters from a checkpoint.
This function uses Hydra to instantiate a module from a config file and optionally loads learned parameters from a checkpoint file using the provided or default checkpointing mechanism. It can also return the full configuration object if needed.
- Parameters:
config_path (str) – Path to the YAML configuration file specifying the model.
ckpt_path (str) – Path to the checkpoint file containing saved parameters.
checkpointing (Checkpointing, optional) – A Checkpointing instance to use for loading parameters. If None, the default REAX checkpointing is used.
return_config (bool) – If True, also return the loaded configuration object.
- Returns:
The instantiated REAX module, optionally accompanied by the loaded configuration.
- Return type:
ReaxModule | tuple[ReaxModule, DictConfig]
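The flow described for load_module (instantiate from a config, optionally restore parameters, optionally return the config as well) can be sketched with the standard library alone. This is not the library's implementation: `instantiate` below is a simplified, hypothetical stand-in for Hydra's `hydra.utils.instantiate` that only handles a flat dict with a `_target_` dotted import path, `load_component` and `saved_params` are hypothetical names, and "restoring a checkpoint" is reduced to attaching a dict.

```python
import importlib
from typing import Any, Dict, Optional


def instantiate(cfg: Dict[str, Any]) -> Any:
    """Build an object from a flat config whose ``_target_`` is a dotted import path.

    Simplified stand-in for Hydra's ``instantiate``: no nesting, no interpolation.
    """
    module_name, _, attr = cfg["_target_"].rpartition(".")
    factory = getattr(importlib.import_module(module_name), attr)
    kwargs = {key: value for key, value in cfg.items() if key != "_target_"}
    return factory(**kwargs)


def load_component(
    cfg: Dict[str, Any],
    saved_params: Optional[Dict[str, Any]] = None,
    return_config: bool = False,
):
    """Hypothetical analogue of ``load_module``."""
    component = instantiate(cfg)
    if saved_params is not None:
        # Stand-in for restoring learned parameters from a checkpoint file.
        component.params = saved_params
    # The return shape depends on the flag, mirroring ReaxModule | tuple[ReaxModule, DictConfig].
    return (component, cfg) if return_config else component


# Example: the target is a stdlib class, used purely for demonstration.
cfg = {"_target_": "types.SimpleNamespace", "name": "demo"}
component, loaded_cfg = load_component(cfg, saved_params={"w": 1.0}, return_config=True)
print(component.name, component.params, loaded_cfg is cfg)  # demo {'w': 1.0} True
```

Because the return shape changes with return_config, callers that pass True have to unpack a tuple; that is what the union in the documented return type expresses.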
tensorial.reaxkit.evaluate module
tensorial.reaxkit.from_data module
- class tensorial.reaxkit.from_data.FromData(cfg, engine, *, rngs=None, dataloader=None, datamodule=None, dataloader_name='train', ignore_missing=True)
Bases: Stage
A trainer stage that populates a Hydra (OmegaConf) configuration dictionary with data statistics calculated from metrics.
- Parameters:
cfg (DictConfig) – The configuration dictionary.
engine (Engine) – The trainer strategy.
rngs (Rngs | None) – The random number generator.
dataloader (DataLoader | None) – The dataloader to use.
datamodule (DataModule | None) – If no dataloader is specified, a data module can be used instead.
dataloader_name (str | None) – The name of the data module's dataloader to use.
ignore_missing (bool) – If True, any metric whose required data is missing is skipped rather than calculated.
- property calculated: dict[str, Any]
The dictionary holding the calculated statistics.
- property dataloader: DataLoader | None
- property dataloaders: DataLoader | None
Dataloader function.
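The idea behind FromData (gather values from a dataloader, reduce them to statistics, write the results into a configuration) can be sketched without REAX. Everything here is hypothetical: a plain dict replaces the OmegaConf DictConfig, a list of dict batches replaces the dataloader, and `populate_from_data` with its `stats` mapping is an invented interface. It does mirror the documented ignore_missing behaviour: a statistic whose input data is absent is skipped when the flag is True.

```python
from statistics import fmean, pstdev
from typing import Any, Callable, Dict, Iterable, List, Tuple

Batch = Dict[str, List[float]]
Reducer = Callable[[List[float]], float]


def populate_from_data(
    cfg: Dict[str, Any],
    batches: Iterable[Batch],
    stats: Dict[str, Tuple[str, Reducer]],
    ignore_missing: bool = True,
) -> Dict[str, float]:
    """Fill ``cfg`` with statistics computed over all batches.

    ``stats`` maps a config key to (data key, reducer). If a data key never
    appears, the statistic is skipped (ignore_missing=True) or a KeyError is raised.
    """
    collected: Dict[str, List[float]] = {}
    for batch in batches:
        for key, values in batch.items():
            collected.setdefault(key, []).extend(values)

    calculated: Dict[str, float] = {}
    for cfg_key, (data_key, reducer) in stats.items():
        if data_key not in collected:
            if ignore_missing:
                continue
            raise KeyError(f"data key {data_key!r} needed for {cfg_key!r} is missing")
        calculated[cfg_key] = reducer(collected[data_key])

    cfg.update(calculated)
    return calculated


cfg = {"lr": 1e-3}
batches = [{"energy": [1.0, 2.0]}, {"energy": [3.0]}]
calculated = populate_from_data(
    cfg, batches, {"energy_mean": ("energy", fmean), "force_std": ("forces", pstdev)}
)
print(calculated)  # {'energy_mean': 2.0}
```

The statistics are computed over the concatenated data rather than per batch, so the mean is exact regardless of how the data was split.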
tensorial.reaxkit.keys module
tensorial.reaxkit.train module
Module contents
The REAX toolkit contains classes and functions that help build a complete model-training application using tensorial and REAX.
- class tensorial.reaxkit.FromData(cfg, engine, *, rngs=None, dataloader=None, datamodule=None, dataloader_name='train', ignore_missing=True)
Bases: Stage
A trainer stage that populates a Hydra (OmegaConf) configuration dictionary with data statistics calculated from metrics.
- Parameters:
cfg (DictConfig) – The configuration dictionary.
engine (Engine) – The trainer strategy.
rngs (Rngs | None) – The random number generator.
dataloader (DataLoader | None) – The dataloader to use.
datamodule (DataModule | None) – If no dataloader is specified, a data module can be used instead.
dataloader_name (str | None) – The name of the data module's dataloader to use.
ignore_missing (bool) – If True, any metric whose required data is missing is skipped rather than calculated.
- property calculated: dict[str, Any]
The dictionary holding the calculated statistics.
- property dataloader: DataLoader | None
- property dataloaders: DataLoader | None
Dataloader function.
- class tensorial.reaxkit.GraphParityPlotter(targets, predictions=None, save_dir='plots/', fit_plot_every=100, x_label=None, y_label=None)
Bases: ParityPlotter
- Parameters:
targets (Union[str, tuple[str, ...]])
predictions (Union[str, tuple[str, ...], None])
save_dir (str | Path)
fit_plot_every (int)
x_label (str | None)
y_label (str | None)
- class tensorial.reaxkit.MetricsPrinter(log_level=20, log_every=10)
Bases: ProgressBar
Prints all scalar metrics found in trainer.progress_bar_metrics, sizing each column dynamically to fit its title or its value.
- Parameters:
log_every (int)
- MIN_COLUMN_WIDTH = 5
- NUMBER_DECIMALS = 5
- enable()
Enable the progress bar.
The Trainer calls this in, for example, pre-training routines such as the learning-rate finder to temporarily enable and disable the training progress bar.
- Return type:
None
- init_train_tqdm(stage)
Override this to customize the tqdm bar for training.
- Parameters:
stage (EpochStage)
- Return type:
tqdm
- property is_disabled: bool
- property is_enabled: bool
- on_stage_end(trainer, stage, /)
The stage is ending.
- Parameters:
trainer (Trainer)
stage (Stage)
- Return type:
None
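A minimal sketch of the column sizing described for MetricsPrinter, assuming the scalar metrics arrive as a plain dict (for example, a copy of trainer.progress_bar_metrics). The function `format_metrics` is hypothetical and does no logging; only the two constants are taken from the class. Printing integers without decimals is an assumption made here so that MIN_COLUMN_WIDTH has a visible effect.

```python
from typing import Dict, Tuple, Union

# Values taken from the documented class attributes.
MIN_COLUMN_WIDTH = 5
NUMBER_DECIMALS = 5


def format_metrics(metrics: Dict[str, Union[int, float]]) -> Tuple[str, str]:
    """Return a (header, row) pair of right-aligned columns.

    Each column is as wide as the longer of its title and its formatted value,
    but never narrower than MIN_COLUMN_WIDTH.
    """
    titles, values = [], []
    for name, value in metrics.items():
        # Integers (e.g. epoch counters) are printed as-is; floats get fixed decimals.
        text = str(value) if isinstance(value, int) else f"{value:.{NUMBER_DECIMALS}f}"
        width = max(MIN_COLUMN_WIDTH, len(name), len(text))
        titles.append(name.rjust(width))
        values.append(text.rjust(width))
    return " ".join(titles), " ".join(values)


header, row = format_metrics({"ep": 3, "loss": 0.5})
print(header)
print(row)
```

Here the "ep" column is padded to the minimum width of 5, while the "loss" column widens to 7 characters to fit "0.50000".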
- class tensorial.reaxkit.ParityPlotter(save_dir='plots/', fit_plot_every=10, x_label='True Values (y)', y_label="Predicted Values (y')")
Bases: TrainerListener
A TrainerListener that collects true and predicted values to create a parity plot at the end of each stage (Train, Validate, Test, Predict).
- Parameters:
save_dir (str | Path)
fit_plot_every (int)
x_label (str)
y_label (str)
- on_fit_end(trainer, stage, /)
Fit has ended; plot the collected training data.
- Parameters:
trainer (Trainer)
stage (Fit)
- Return type:
None
- on_predict_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)
The predict stage has just finished processing a batch.
- Parameters:
_trainer (Trainer)
stage (Predict)
outputs (Any)
batch (Any)
_batch_idx (int)
- Return type:
None
- on_predict_end(trainer, stage, /)
Predict is ending; plot the collected prediction data.
- Parameters:
trainer (Trainer)
stage (Predict)
- Return type:
None
- on_test_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)
The test stage has just finished processing a batch.
- Parameters:
_trainer (Trainer)
stage (Test)
outputs (Any)
batch (Any)
_batch_idx (int)
- Return type:
None
- on_test_end(trainer, stage, /)
Test has ended; plot the collected test data.
- Parameters:
trainer (Trainer)
stage (Test)
- Return type:
None
- on_train_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)
The training stage has just finished processing a batch.
- Parameters:
_trainer (Trainer)
stage (Train)
outputs (Any)
batch (Any)
_batch_idx (int)
- Return type:
None
- on_train_end(trainer, stage, /)
Training is ending; plot the collected training data.
- Parameters:
trainer (Trainer)
stage (Train)
- on_validation_batch_end(_trainer, stage, outputs, batch, _batch_idx, /)
The validation stage has just finished processing a batch.
- Parameters:
_trainer (Trainer)
stage (Validate)
outputs (Any)
batch (Any)
_batch_idx (int)
- Return type:
None
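The collect-then-plot pattern of ParityPlotter can be sketched without REAX or matplotlib. `ParityCollector` is hypothetical: it accumulates values once per batch, as the on_*_batch_end hooks are described to do, and at stage end computes the numbers a parity plot typically annotates (least-squares fit, R², RMSE) instead of drawing a figure. How ParityPlotter extracts targets and predictions from outputs and batch is not documented here, so the sketch takes them as plain lists.

```python
import math
from typing import Dict, Iterable, List


class ParityCollector:
    """Hypothetical helper: gathers (true, predicted) pairs over one stage."""

    def __init__(self) -> None:
        self.y_true: List[float] = []
        self.y_pred: List[float] = []

    def on_batch_end(self, targets: Iterable[float], predictions: Iterable[float]) -> None:
        # Called once per batch, like the on_*_batch_end hooks.
        self.y_true.extend(targets)
        self.y_pred.extend(predictions)

    def on_stage_end(self) -> Dict[str, float]:
        """Summarise the collected pairs, then reset for the next stage."""
        x, y = self.y_true, self.y_pred
        n = len(x)
        if n == 0:
            return {}
        mean_x, mean_y = sum(x) / n, sum(y) / n
        sxx = sum((xi - mean_x) ** 2 for xi in x)
        sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
        # Least-squares line y' = slope * y + intercept (undefined if all targets are equal).
        slope = sxy / sxx if sxx else float("nan")
        intercept = mean_y - slope * mean_x
        ss_res = sum((yi - xi) ** 2 for xi, yi in zip(x, y))
        # R^2 of the predictions measured against the ideal parity line y' = y.
        r2 = 1.0 - ss_res / sxx if sxx else float("nan")
        rmse = math.sqrt(ss_res / n)
        self.y_true, self.y_pred = [], []
        return {"slope": slope, "intercept": intercept, "r2": r2, "rmse": rmse}


collector = ParityCollector()
collector.on_batch_end([0.0, 1.0], [1.0, 2.0])
collector.on_batch_end([2.0], [3.0])
summary = collector.on_stage_end()
print(summary)
```

In this example every prediction is offset by +1, so the fitted slope is 1 and the intercept 1, while R² against the parity line is negative: a good linear fit does not imply accurate predictions, which is why parity plots usually show both.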
- class tensorial.reaxkit.ReaxModule(model, loss_fn, optimizer, scheduler=None, metrics=None, jit=True, donate_graph=False, output=('predictions', 'targets'))
Bases: Module[InputT, OutputT_co]
Tensorial REAX module.
- Parameters:
model (Module)
loss_fn (Callable[[OutputT_co, InputT], Array])
optimizer (GradientTransformation | Callable[[], GradientTransformation])
scheduler (Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]] | None)
metrics (dict[str, Metric | str] | None)
output (Sequence[str] | None)
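The loss_fn type, Callable[[OutputT_co, InputT], Array], says the loss receives the model output and the input batch, with no separate targets argument. A minimal stdlib sketch of functions shaped to that contract, assuming (hypothetically) that reference values travel inside the batch dict under the key "energy"; `linear_model`, `energy_mse`, and the key names are illustrative only, and real code would operate on JAX arrays.

```python
from statistics import fmean
from typing import Dict, List

# Hypothetical batch layout: inputs and reference values share one dict.
Batch = Dict[str, List[float]]
Predictions = Dict[str, List[float]]


def linear_model(params: Dict[str, float], batch: Batch) -> Predictions:
    """Toy model with a (params, inputs) -> outputs shape."""
    return {"energy": [params["w"] * x + params["b"] for x in batch["x"]]}


def energy_mse(predictions: Predictions, batch: Batch) -> float:
    """Loss with the (predictions, inputs) argument order of loss_fn."""
    pairs = zip(predictions["energy"], batch["energy"])
    return fmean((pred - ref) ** 2 for pred, ref in pairs)


batch = {"x": [0.0, 1.0, 2.0], "energy": [1.0, 3.0, 5.0]}
print(energy_mse(linear_model({"w": 2.0, "b": 1.0}, batch), batch))  # 0.0
print(energy_mse(linear_model({"w": 2.0, "b": 0.0}, batch), batch))  # 1.0
```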
- static calculate_metrics(predictions, targets, metrics)
- Parameters:
predictions (OutputT_co)
targets (OutputT_co)
metrics (dict[str, Metric | str])
- Return type:
dict[str, Metric]
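calculate_metrics accepts each metric either as an object or as a string and returns metric objects rather than plain numbers. A stdlib sketch of why that is useful, assuming string names are resolved through a lookup table and each metric keeps mergeable state; `REGISTRY`, `MeanMetric`, and the error functions are hypothetical and do not reflect REAX's Metric API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Union

ErrorFn = Callable[[float, float], float]

# Hypothetical name -> per-sample error function lookup for metrics given as strings.
REGISTRY: Dict[str, ErrorFn] = {
    "mae": lambda pred, true: abs(pred - true),
    "mse": lambda pred, true: (pred - true) ** 2,
}


@dataclass(frozen=True)
class MeanMetric:
    """Hypothetical metric state: a running total and sample count."""

    total: float = 0.0
    count: int = 0

    def merge(self, other: "MeanMetric") -> "MeanMetric":
        return MeanMetric(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        return self.total / self.count


def calculate_metrics(
    predictions: List[float],
    targets: List[float],
    metrics: Dict[str, Union[ErrorFn, str]],
) -> Dict[str, MeanMetric]:
    results = {}
    for name, spec in metrics.items():
        # Strings are resolved through the registry; callables are used directly.
        error_fn = REGISTRY[spec] if isinstance(spec, str) else spec
        errors = [error_fn(p, t) for p, t in zip(predictions, targets)]
        results[name] = MeanMetric(sum(errors), len(errors))
    return results


spec = {"mae": "mae", "mse": "mse"}
first = calculate_metrics([1.0, 2.0], [1.0, 3.0], spec)
second = calculate_metrics([4.0], [2.0], spec)
overall_mae = first["mae"].merge(second["mae"]).compute()
print(overall_mae)  # 1.0
```

Merging the two per-batch states gives the mean over all three samples (1.0), whereas averaging the per-batch means (0.5 and 2.0) would give 1.25.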
- configure_model(_stage, batch, /)
Called at the beginning of each stage.
A chance to configure the model. This method should be idempotent, i.e. calling it a second time should do nothing.
- Parameters:
_stage (Stage)
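The idempotency requirement usually means "initialise once, then do nothing". A minimal sketch, assuming parameters are created lazily from the shape of the first batch; `LazyModel` and its parameter layout are hypothetical and unrelated to ReaxModule's internals.

```python
from typing import Dict, List, Optional


class LazyModel:
    """Hypothetical module whose parameters are sized from the first batch it sees."""

    def __init__(self) -> None:
        self.params: Optional[Dict[str, List[float]]] = None
        self.init_calls = 0

    def configure_model(self, stage: str, batch: List[List[float]]) -> None:
        # Idempotent: once parameters exist, later calls (e.g. when the next
        # stage starts) return immediately and leave them untouched.
        if self.params is not None:
            return
        self.init_calls += 1
        n_features = len(batch[0])
        self.params = {"weights": [0.0] * n_features, "bias": [0.0]}


model = LazyModel()
model.configure_model("fit", [[1.0, 2.0, 3.0]])
first_params = model.params
model.configure_model("test", [[9.0]])  # different batch shape, but nothing changes
print(model.init_calls, model.params is first_params)  # 1 True
```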
- property debug: bool
- on_before_optimizer_step(_optimizer, grad, /)
Called before optimizer.step().
If using gradient accumulation, the hook is called once the gradients have been accumulated (see accumulate_grad_batches).
If clipping gradients, the gradients will not have been clipped yet.
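The ordering described for on_before_optimizer_step (accumulate, call the hook, clip, then step) can be sketched in plain Python. This is a hypothetical loop, not REAX's trainer: gradients are flat lists, clipping is by global L2 norm, and `run_accumulation`, `max_norm`, and the callbacks are invented for illustration.

```python
import math
from typing import Callable, List, Optional

Grad = List[float]


def run_accumulation(
    batch_grads: List[Grad],
    accumulate_grad_batches: int,
    max_norm: float,
    on_before_optimizer_step: Callable[[Grad], None],
    optimizer_step: Callable[[Grad], None],
) -> None:
    """Sum gradients over N batches, call the hook, then clip and step.

    Trailing batches that do not fill a complete window are dropped.
    """
    acc: Optional[Grad] = None
    for i, grad in enumerate(batch_grads, start=1):
        acc = list(grad) if acc is None else [a + g for a, g in zip(acc, grad)]
        if i % accumulate_grad_batches == 0:
            # The hook sees the accumulated gradient before any clipping.
            on_before_optimizer_step(acc)
            norm = math.sqrt(sum(g * g for g in acc))
            scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
            optimizer_step([g * scale for g in acc])
            acc = None


seen, applied = [], []
run_accumulation(
    [[3.0, 0.0], [1.0, 0.0], [0.0, 4.0], [0.0, 4.0]],
    accumulate_grad_batches=2,
    max_norm=1.0,
    on_before_optimizer_step=seen.append,
    optimizer_step=applied.append,
)
print(seen)     # [[4.0, 0.0], [0.0, 8.0]]
print(applied)  # [[1.0, 0.0], [0.0, 1.0]]
```

The hook fires once per accumulation window (twice for four batches with accumulate_grad_batches=2) and observes the summed, unclipped gradients; only the values passed to the optimizer step are rescaled.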
- predict_step(batch, _batch_idx, /)
Make a model prediction and return the result.
- Parameters:
batch (InputT)
_batch_idx (int)
- Return type:
OutputT_co
- static step(params, inputs, _targets, model, loss_fn, metrics=None, output=())
Calculate the loss and, optionally, metrics.
- Parameters:
params (PyTree)
inputs (InputT)
_targets (OutputT_co)
model (Callable[[PyTree, InputT], OutputT_co])
loss_fn (Callable[[OutputT_co, InputT], Array])
metrics (MetricCollection | None)
output (tuple[str, ...])
- Return type:
tuple[Array, dict]
- test_step(batch, _batch_idx, /)
Test step.
- Parameters:
batch (tuple[InputT, OutputT_co])
_batch_idx (int)
- Return type:
StepOutput | None