Source code for tensorial.gcnn.data._common
import collections
import enum
from typing import Any
import jraph
import reax
__all__ = "GraphBatch", "GraphDataset", "GraphPadding", "GraphAttributes", "BatchMode"
GraphsOrGraphsTuple = jraph.GraphsTuple | tuple[jraph.GraphsTuple, ...]
[docs]
class GraphBatch(tuple):
inputs: jraph.GraphsTuple
targets: Any | None
GraphDataset = reax.data.Dataset[GraphBatch] # pylint: disable=invalid-name
GraphPadding = collections.namedtuple("GraphPadding", ["n_nodes", "n_edges", "n_graphs"])
[docs]
class GraphAttributes(enum.IntFlag):
NODES = 0b0001
EDGES = 0b0010
GLOBALS = 0b0100
ALL = NODES | EDGES | GLOBALS
[docs]
class BatchMode(str, enum.Enum):
IMPLICIT = "implicit"
EXPLICIT = "explicit"