Source code for tensorial.gcnn.random
from collections.abc import Callable
import beartype
import jax.random
import jaxtyping as jt
import jraph
from . import _spatial
# Signature of an attribute generator.  Within this module such a function is
# invoked as ``fn(rng_key, num)`` and its return value is used as the attribute.
RandomFn = Callable[[jax.typing.ArrayLike, int], jax.typing.ArrayLike]
# An attribute supplied either directly as array-like data, or as a
# ``RandomFn`` that generates the data on demand.
LiteralOrRandom = jax.typing.ArrayLike | RandomFn
[docs]
@jt.jaxtyped(typechecker=beartype.beartype)
def spatial_graph(
    rng_key: jax.Array,
    num_nodes: int | None = None,
    num_graphs=None,
    cutoff=0.4,
    nodes: dict[str, LiteralOrRandom] | None = None,
) -> jraph.GraphsTuple | list[jraph.GraphsTuple]:
    """Create graph(s) with nodes that have random positions.

    Node positions have shape ``(n, 3)`` and are drawn with
    ``jax.random.uniform``.

    Args:
        rng_key: JAX PRNG key.  It is split internally before every random
            draw, so the passed key itself is never consumed directly.
        num_nodes: Number of nodes in each graph.  If ``None``, the number of
            nodes is drawn independently for every graph with
            ``jax.random.randint(minval=2, maxval=10)``.
        num_graphs: Number of graphs to create.  If ``None``, one graph is
            created and returned directly instead of inside a list.
        cutoff: Forwarded to ``graph_from_points`` as ``r_max``.
        nodes: Optional node attributes keyed by name.  Callable values are
            invoked as ``fn(rng_key, n_nodes)`` once per graph; any other value
            is used as given.  The dictionary is never modified.

    Returns:
        A single graph when ``num_graphs`` is ``None``, otherwise a list with
        ``num_graphs`` graphs (empty when ``num_graphs`` is ``0``).
    """
    # ``num_graphs or 1`` would turn an explicit 0 into one graph.
    graph_count = 1 if num_graphs is None else num_graphs

    graphs = []
    for _ in range(graph_count):
        if num_nodes is None:
            rng_key, subkey = jax.random.split(rng_key)
            # Convert to a plain ``int``: ``randint`` returns a JAX scalar,
            # which would not satisfy the ``int`` annotation checked on
            # ``_create_attributes``.  A local name is used so that the next
            # graph draws its own node count.
            n_nodes = int(jax.random.randint(subkey, shape=(), minval=2, maxval=10))
        else:
            n_nodes = num_nodes

        rng_key, subkey = jax.random.split(rng_key)
        pos = jax.random.uniform(subkey, shape=(n_nodes, 3))

        # Build a fresh attribute dict for every graph.  Writing the generated
        # values back into ``nodes`` would replace the caller's random
        # functions, so every later graph would reuse the first graph's values.
        graph_nodes = None
        if nodes is not None:
            graph_nodes = {}
            for key, value in nodes.items():
                attribute, rng_key = _create_attributes(value, rng_key, n_nodes)
                graph_nodes[key] = attribute

        graphs.append(_spatial.graph_from_points(pos, r_max=cutoff, nodes=graph_nodes))

    if num_graphs is None:
        return graphs[0]
    return graphs
@jt.jaxtyped(typechecker=beartype.beartype)
def _create_attributes(
    value: LiteralOrRandom, rng_key: jax.Array, num: int
) -> tuple[jax.typing.ArrayLike, jax.Array]:
    """Resolve an attribute specification into concrete values.

    Args:
        value: Either literal attribute data or a callable that generates it.
        rng_key: PRNG key; only split when ``value`` is callable.
        num: Passed as the second argument to ``value`` when it is callable.

    Returns:
        A ``(attribute, rng_key)`` pair.  For literal data the inputs are
        handed back unchanged; for a callable, the attribute is
        ``value(subkey, num)`` and the key is the carried-forward half of the
        split.
    """
    # Literal data needs no randomness, so the key is returned untouched.
    if not isinstance(value, Callable):
        return value, rng_key

    next_key, draw_key = jax.random.split(rng_key)
    generated = value(draw_key, num)
    return generated, next_key