Source code for tensorial.gcnn.data._datamodule

from collections.abc import Sequence
from typing import TYPE_CHECKING, Final

import jraph
import reax
from typing_extensions import override

from . import _batching, _common, _dataloader

if TYPE_CHECKING:
    from tensorial import gcnn

__all__ = ("GraphDataModule",)

Dataset = Sequence[jraph.GraphsTuple]


[docs] class GraphDataModule(reax.DataModule): """A data module that serves jraph.GraphsTuples""" def __init__( self, dataset: Sequence[jraph.GraphsTuple], train_val_test_split: Sequence[int | float] = (0.85, 0.05, 0.1), batch_size: int = 32, batch_mode: "gcnn.data.BatchMode | str" = _common.BatchMode.IMPLICIT, ): """Initialize the module Args: dataset: The data loader of all the graphs to use train_val_test_split: The train, validation, and test split. batch_size: The batch size. Defaults to `32`. """ super().__init__() # Params self._train_val_test_split: Final[tuple[int | float, ...]] = tuple(train_val_test_split) self._batch_size: Final[int] = batch_size # State self._dataloader = dataset self.batch_size_per_device = batch_size self.data_train: Dataset | None = None self.data_val: Dataset | None = None self.data_test: Dataset | None = None self._max_padding: "gcnn.data.GraphPadding | None" = None self._batch_mode = batch_mode
[docs] @override def setup(self, stage: "reax.Stage", /) -> None: """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. This method is called by REAX before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to `self.setup()` once the data is prepared and available for use. Args: stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. """ # load and split datasets only if not loaded already if not self.data_train and not self.data_val and not self.data_test: # Split up the data train, val, test = reax.data.random_split( self.rngs, dataset=self._dataloader, lengths=self._train_val_test_split, ) graph_datasets: dict[str, Dataset] = dict(train=train, val=val, test=test) # Calculate the maximum padding to use paddings: "list[gcnn.data.GraphPadding]" = [] # Padding is computed with _batching.max_padding(*paddings) -- few lines below. # In explicit mode, if we used self._batch_size to compute the padding, # the number of padding graphs would follow the implicit batching logic. # Thus, in explicit mode we would obtain for the resulting batch: # batch.n_node.shape == (self._batch_size, self._batch_size + 1) # With this condition, we instead enforce: # batch.n_node.shape == (self._batch_size, 2) # which is the expected behavior for explicit batching. for graphs in graph_datasets.values(): if self._batch_mode is _common.BatchMode.IMPLICIT: paddings.append( _batching.GraphBatcher.calculate_padding(graphs, self._batch_size) ) else: paddings.append(_batching.GraphBatcher.calculate_padding(graphs, 1)) self.data_train = graph_datasets["train"] self.data_val = graph_datasets["val"] self.data_test = graph_datasets["test"] # Calculate a padding that will work for all the datasets. self._max_padding = _batching.max_padding(*paddings)
[docs] @override def train_dataloader(self) -> reax.DataLoader: """Create and return the train dataloader. Returns: The train dataloader. """ if self.data_train is None: raise reax.exceptions.MisconfigurationException( "Must call setup() before requesting the dataloader" ) return _dataloader.GraphLoader( self.data_train, batch_size=self._batch_size, padding=self._max_padding, pad=True, batch_mode=self._batch_mode, )
[docs] @override def val_dataloader(self) -> reax.DataLoader: """Create and return the validation dataloader. Returns: The validation dataloader. """ if self.data_val is None: raise reax.exceptions.MisconfigurationException( "Must call setup() before requesting the dataloader" ) return _dataloader.GraphLoader( self.data_val, batch_size=self.batch_size_per_device, shuffle=False, padding=self._max_padding, pad=True, batch_mode=self._batch_mode, )
[docs] @override def test_dataloader(self) -> reax.DataLoader: """Create and return the test dataloader. Returns: The test dataloader. """ if self.data_test is None: raise reax.exceptions.MisconfigurationException( "Must call setup() before requesting the dataloader" ) return _dataloader.GraphLoader( self.data_test, batch_size=self.batch_size_per_device, shuffle=False, padding=self._max_padding, pad=True, batch_mode=self._batch_mode, )