tensorial.gcnn.data package#

Module contents#

class tensorial.gcnn.data.BatchMode(*values)[source]#

Bases: str, Enum

EXPLICIT = 'explicit'#
IMPLICIT = 'implicit'#
class tensorial.gcnn.data.GraphAttributes(*values)[source]#

Bases: IntFlag

ALL = 7#
EDGES = 2#
GLOBALS = 4#
NODES = 1#
class tensorial.gcnn.data.GraphBatch(iterable=(), /)[source]#

Bases: tuple

inputs: GraphsTuple#
targets: Any | None#
class tensorial.gcnn.data.GraphBatcher(graphs, batch_size=1, *, shuffle=False, pad=False, add_mask=True, padding=None, pad_to_multiple=None, drop_last=False, mode=BatchMode.IMPLICIT)[source]#

Bases: Iterable[GraphsTuple]

Take an iterable of GraphsTuples and break it up into batches.

Parameters:
  • graphs (GraphsTuple | Sequence[GraphsTuple])

  • batch_size (int)

  • shuffle (bool)

  • pad (bool)

  • add_mask (bool)

  • padding (GraphPadding | None)

  • pad_to_multiple (int | str | Device | None)

  • drop_last (bool)

  • mode (str | BatchMode)

property batch_size: int#
static calculate_padding(graphs, batch_size, with_shuffle=False, pad_to_multiple=None)[source]#

Calculate the padding necessary to fit the given graphs into a batch

Parameters:
  • graphs (Sequence[GraphsTuple])

  • batch_size (int)

  • with_shuffle (bool)

  • pad_to_multiple (int | str | Device | None)

Return type:

GraphPadding

fetch(idxs)[source]#
Parameters:

idxs (Sequence[int])

Return type:

GraphsTuple

property padding: GraphPadding#
class tensorial.gcnn.data.GraphDataModule(dataset, train_val_test_split=(0.85, 0.05, 0.1), batch_size=32, batch_mode=BatchMode.IMPLICIT)[source]#

Bases: DataModule

A data module that serves jraph.GraphsTuples

Parameters:
  • dataset (Sequence[GraphsTuple])

  • train_val_test_split (Sequence[int | float])

  • batch_size (int)

  • batch_mode (BatchMode | str)

Initialize the module

Parameters:
  • dataset – The dataset containing all the graphs to use

  • train_val_test_split – The train, validation, and test split.

  • batch_size – The batch size. Defaults to 32.

  • dataset (Sequence[GraphsTuple])

  • train_val_test_split (Sequence[int | float])

  • batch_size (int)

  • batch_mode (BatchMode | str)

setup(stage, /)[source]#

Load data. Set variables: self.data_train, self.data_val, self.data_test.

This method is called by REAX before trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict(), so be careful not to execute things like random split twice! Also, it is called after self.prepare_data() and there is a barrier in between which ensures that all the processes proceed to self.setup() once the data is prepared and available for use.

Parameters:

stage – The stage to set up. Either “fit”, “validate”, “test”, or “predict”.

Return type:

None

Defaults to None.

Parameters:

stage (Stage)

test_dataloader()[source]#

Create and return the test dataloader.

Return type:

DataLoader

Returns:

The test dataloader.

train_dataloader()[source]#

Create and return the train dataloader.

Return type:

DataLoader

Returns:

The train dataloader.

val_dataloader()[source]#

Create and return the validation dataloader.

Return type:

DataLoader

Returns:

The validation dataloader.

class tensorial.gcnn.data.GraphLoader(*datasets, batch_size=1, shuffle=False, pad=None, padding=None, batch_mode=BatchMode.IMPLICIT, sampler=None)[source]#

Bases: DataLoader[GraphsTuple | tuple[GraphsTuple, …], GraphsTuple | tuple[GraphsTuple, …]]

Data loader for graphs

Parameters:
  • datasets (GraphsTuple | Sequence[GraphsTuple] | None)

  • batch_size (int)

  • shuffle (bool)

  • pad (bool | None)

  • padding (GraphPadding | None)

  • batch_mode (BatchMode | str)

  • sampler (Iterable[TypeVar(IdxT)])

property batch_size: int#
property dataset#

The dataset being loaded.

property padding: GraphPadding#
property sampler: Iterable[IdxT]#

Access the index sampler used by the dataloader

property shuffle: bool#
with_new_sampler(sampler)[source]#

Recreate the loader with the given index sampler

Parameters:

sampler (Iterable[TypeVar(IdxT)])

Return type:

GraphLoader

class tensorial.gcnn.data.GraphPadding(n_nodes, n_edges, n_graphs)#

Bases: tuple

Create new instance of GraphPadding(n_nodes, n_edges, n_graphs)

n_edges#

Alias for field number 1

n_graphs#

Alias for field number 2

n_nodes#

Alias for field number 0

tensorial.gcnn.data.add_padding_mask(graph, mask_field='mask', what=<GraphAttributes.ALL: 7>, overwrite=False, np_=None)[source]#

Add a mask array to the mask_field of graph for nodes, edges and/or globals. The mask can be used to determine which entries are present only as padding (and should therefore be ignored in any computations).

If overwrite is True, any mask already present in the mask field is overwritten by the padding mask. Otherwise, the existing mask is combined with the padding mask using a logical AND.

Parameters:

graph (GraphsTuple)

Return type:

GraphsTuple

tensorial.gcnn.data.generated_padded_graphs(dataset, add_mask=False, num_nodes=None, num_edges=None, num_graphs=None)[source]#

Provides an iterator over batches of graphs tuples, each padded so that the number of nodes, edges and graphs in every batch equals the maximum found in the dataset.

Parameters:

dataset (Iterable[GraphBatch] | Sequence[GraphBatch])

Return type:

Iterator[GraphBatch]

tensorial.gcnn.data.max_padding(*padding)[source]#

Get a padding that contains the maximum number of nodes, edges, and graphs over all the provided paddings.

Parameters:

padding (GraphPadding)

Return type:

GraphPadding

tensorial.gcnn.data.pad_with_graphs(graph, n_node, n_edge, n_graph=2, mask_field='mask', overwrite_mask=False)[source]#
Parameters:
  • graph (GraphsTuple)

  • n_node (int)

  • n_edge (int)

  • n_graph (int)

  • mask_field (str | None)

Return type:

GraphsTuple