Source code for tensorial.gcnn.data._dataloader

from collections.abc import Iterator, Sequence
import functools
from typing import TYPE_CHECKING, Final

import beartype
import jaxtyping as jt
import jraph
import reax
from typing_extensions import override

from . import _batching, _common

if TYPE_CHECKING:
    from tensorial import gcnn

__all__ = ("GraphLoader",)


class GraphLoader(reax.DataLoader[_common.GraphsOrGraphsTuple, _common.GraphsOrGraphsTuple]):
    """Data loader for graphs.

    Holds one or more parallel datasets of graphs and, driven by a single index sampler,
    yields a tuple with one batched entry per dataset on every iteration.  A dataset that
    was passed as ``None`` is carried through as ``None`` in every yielded tuple.
    """

    @jt.jaxtyped(typechecker=beartype.beartype)
    def __init__(
        self,
        *datasets: jraph.GraphsTuple | Sequence[jraph.GraphsTuple] | None,
        batch_size: int = 1,
        shuffle: bool = False,
        pad: bool | None = None,
        padding: "gcnn.data.GraphPadding | None" = None,
        batch_mode: "gcnn.data.BatchMode | str" = _common.BatchMode.IMPLICIT,
        sampler: reax.data.Sampler | None = None,
    ):
        """Create the loader.

        :param datasets: the graph datasets to load from; each one is either an already
            batched ``jraph.GraphsTuple`` (which gets unbatched into individual graphs), a
            sequence of graphs, or ``None`` as a placeholder
        :param batch_size: forwarded to the sampler factory and to every graph batcher
        :param shuffle: forwarded to the sampler factory and to every graph batcher
        :param pad: whether batches should be padded; if ``None`` this is taken to be
            ``True`` exactly when ``padding`` is supplied
        :param padding: explicit padding, forwarded to every graph batcher
        :param batch_mode: the batching mode, forwarded to every graph batcher
        :param sampler: the index sampler to use; if ``None`` one is created from the first
            dataset that is not ``None``
        """
        # Params
        self._batch_size: Final[int] = batch_size
        self._shuffle: Final[bool] = shuffle

        # State
        # If the graphs were supplied as GraphTuples then unbatch them to have a base sequence of
        # individual graphs per input
        self._dataset = tuple(
            jraph.unbatch_np(graphs) if isinstance(graphs, jraph.GraphsTuple) else graphs
            for graphs in datasets
        )
        self._sampler = self._create_sampler(self._dataset, batch_size, shuffle, sampler)

        if pad is None:
            pad = padding is not None

        # Remember the resolved batching options so that `with_new_sampler` can rebuild an
        # equivalent loader instead of silently falling back to the defaults
        self._pad: Final[bool] = pad
        self._batch_mode = batch_mode

        create_batcher = functools.partial(
            _batching.GraphBatcher,
            batch_size=batch_size,
            shuffle=shuffle,
            pad=pad,
            padding=padding,
            mode=batch_mode,
        )
        self._batchers: "tuple[gcnn.data.GraphBatcher | None, ...]" = tuple(
            create_batcher(graph_batch) if graph_batch is not None else None
            for graph_batch in self._dataset
        )

    @staticmethod
    def _create_sampler(
        dataset, batch_size: int, shuffle: bool, sampler: reax.data.Sampler | None
    ) -> reax.data.Sampler:
        """Return ``sampler`` if given, otherwise create one from the first non-``None`` dataset.

        :raises ValueError: if no sampler was given and there is no non-``None`` dataset to
            create one from
        """
        if sampler is not None:
            return sampler

        example = next((graphs for graphs in dataset if graphs is not None), None)
        if example is None:
            # Previously this surfaced as a bare StopIteration from `next()`
            raise ValueError(
                "Cannot create a sampler: no sampler was supplied and all datasets are None"
            )

        return reax.data.samplers.create_sampler(example, batch_size=batch_size, shuffle=shuffle)

    @property
    def batch_size(self) -> int:
        """The batch size this loader was created with"""
        return self._batch_size

    @property
    def shuffle(self) -> bool:
        """The shuffle flag this loader was created with"""
        return self._shuffle

    @property
    @override
    def dataset(self):
        """The tuple of datasets (as sequences of individual graphs, or ``None`` placeholders)"""
        return self._dataset

    @property
    def padding(self) -> "gcnn.data.GraphPadding | None":
        """The padding reported by the first batcher that exists.

        Datasets passed as ``None`` have no batcher and are skipped; ``None`` is returned if
        there is no batcher at all.
        """
        batcher = next((entry for entry in self._batchers if entry is not None), None)
        if batcher is None:
            return None
        return batcher.padding

    @property
    def sampler(self) -> "reax.data.Sampler":
        """Access the index sampler used by the dataloader"""
        return self._sampler

    @override
    def __len__(self) -> int:
        """The length of the index sampler"""
        return len(self.sampler)

    @override
    def __iter__(self) -> Iterator[tuple[jraph.GraphsTuple | None, ...]]:
        """Yield, for each set of indices from the sampler, one fetched batch per dataset"""
        for idxs in self._sampler:
            batch_graphs = tuple(
                batcher.fetch(idxs) if batcher is not None else None for batcher in self._batchers
            )
            yield batch_graphs

    @override
    def with_new_sampler(self, sampler: "reax.data.Sampler") -> "GraphLoader":
        """Recreate the loader with the given index sampler"""
        return GraphLoader(
            *self._dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            # Carry over the resolved padding flag and batch mode, otherwise the new loader
            # would re-derive/default them and could batch differently from this one
            pad=self._pad,
            padding=self.padding,
            batch_mode=self._batch_mode,
            sampler=sampler,
        )