Source code for tensorial.reaxkit.cli
"""Main CLI command"""
import argparse
import os
import pathlib
import sys
from typing import Final, cast
import hydra
from tensorial import reaxkit as rkit
# Name of the argparse namespace attribute that holds the chosen sub-command
COMMAND: Final[str] = "command"
# Sub-command names accepted on the command line
TRAIN: Final[str] = "train"
PREDICT: Final[str] = "predict"
# Default values of the ``-i/--input`` option for the 'train' and 'predict' sub-commands
TRAIN_SCRIPT_DEFAULT: Final[str] = "configs/train.yaml"
EVAL_SCRIPT_DEFAULT: Final[str] = "configs/eval.yaml"
# Name of the environment variable in which ``main_cli`` stores the full command line
REAX_COMMAND: Final[str] = "REAX_COMMAND"
[docs]
def main_cli():
    """Entry point of the ``tensorial`` command line interface.

    Parses the sub-command (``train`` or ``predict``) together with its
    ``-i/--input`` option, then hands every argument that argparse did not
    recognise over to Hydra, which launches the selected task function using
    the input file as its configuration.

    The complete original command line is stored in the ``REAX_COMMAND``
    environment variable.

    For ``predict`` the input may also be a directory, in which case
    ``rkit.config.DEFAULT_CONFIG_FILE`` inside that directory is used.

    Raises:
        SystemExit: with status 1 if the ``predict`` configuration file does
            not exist (argparse also exits on invalid arguments).
        ValueError: if the parsed sub-command is neither ``train`` nor
            ``predict``.
    """
    # Record the full invocation now, because sys.argv is rewritten further down
    os.environ[REAX_COMMAND] = " ".join(sys.argv)

    parser = argparse.ArgumentParser("tensorial")
    commands = parser.add_subparsers(dest=COMMAND, required=True)

    # argparse does not run `type` on `const`, so the fallbacks are built as Paths here.
    # Supplying `const` makes a bare `-i` (no value) fall back to the default instead of
    # producing `None`, which would crash below when the path is inspected.
    train_default = pathlib.Path(TRAIN_SCRIPT_DEFAULT)
    eval_default = pathlib.Path(EVAL_SCRIPT_DEFAULT)

    # The 'train' command
    train_parser = commands.add_parser(TRAIN, help="Train a model")
    train_parser.add_argument(
        "-i",
        "--input",
        nargs="?",
        type=pathlib.Path,
        help="Input file with training details",
        default=train_default,
        const=train_default,
    )

    # The 'predict' command
    predict_parser = commands.add_parser(PREDICT, help="Make predictions using a trained model")
    predict_parser.add_argument(
        "-i",
        "--input",
        nargs="?",
        type=pathlib.Path,
        help="Input file with evaluation details",
        default=eval_default,
        const=eval_default,
    )

    # Parse the args; anything argparse does not know about is kept for Hydra
    args, hydra_args = parser.parse_known_args()
    command = getattr(args, COMMAND)
    script_path = cast(pathlib.Path, args.input)

    if command == TRAIN:
        task = rkit.train.main
    elif command == PREDICT:
        # A directory is accepted as shorthand for the default config file inside it
        if script_path.is_dir():
            script_path = script_path / rkit.config.DEFAULT_CONFIG_FILE

        if not script_path.is_file():
            # sys.exit with a string writes it to stderr and exits with status 1
            sys.exit(f"Could not find configuration file: {script_path}")

        task = rkit.evaluate.main
    else:
        raise ValueError(f"Unrecognised command '{command}'")

    # Set the command line arguments to what remains so hydra can deal with it
    sys.argv = sys.argv[:1] + hydra_args

    hydra_fn = hydra.main(
        version_base="1.3",
        config_path=str(script_path.parent.absolute()),
        config_name=script_path.stem,
    )(task)

    # Call Hydra to launch the actual command
    hydra_fn()
# Run the command line interface when this module is executed as a script
if __name__ == "__main__":
    main_cli()