from collections.abc import Hashable, Iterable, Mapping, MutableMapping
import numbers
from typing import TYPE_CHECKING
import jraph
import numpy as np
from tensorial.typing import Array, CellType, PbcType
from . import keys
from .. import _spatial as gcnn_graphs
from ... import base
if TYPE_CHECKING:
try:
import ase
except ImportError:
pass
try:
import pymatgen
except ImportError:
pass
__all__ = "graph_from_pymatgen", "graph_from_ase"
# too slow: @jt.jaxtyped(typechecker=beartype.beartype)
[docs]
def graph_from_pymatgen(
pymatgen_structure: "pymatgen.core.SiteCollection",
r_max: numbers.Number,
*,
key_mapping: dict[str, str] | None = None,
atom_include_keys: Iterable | None = ("numbers",),
edge_include_keys: Iterable | None = tuple(),
global_include_keys: Iterable | None = tuple(),
cell: CellType | None = None,
pbc: bool | PbcType | None = None,
graph_globals: dict[str, Array] | None = None,
**kwargs,
) -> jraph.GraphsTuple:
"""Create a jraph Graph from a pymatgen SiteCollection object or subclass
(e.g. Structure, Molecule)
Note that the special atom key "numbers" is used to retrieve atomic numbers using
SiteCollection.atomic_numbers.
All other keys are used to retrieve site properties using SiteCollection.site_properties.
Args:
pymatgen_structure: the SiteCollection object
r_max: the maximum neighbour distance to use when considering
two atoms to be neighbours
key_mapping
atom_include_keys
global_include_keys
cell: an optional unit cell (otherwise will be taken from
Structure.lattice.matrix if it exists)
pbc: an optional periodic boundary conditions array [bool, bool,
bool] (otherwise will be taken from Structure.lattice.pbc if
it exists)
Returns:
the atomic graph
"""
# pylint: disable=too-many-branches
key_mapping = key_mapping or {}
_key_mapping = {
"forces": keys.FORCES,
"energy": keys.TOTAL_ENERGY,
"numbers": keys.ATOMIC_NUMBERS,
}
_key_mapping.update(key_mapping)
key_mapping = _key_mapping
del _key_mapping
positions = pymatgen_structure.cart_coords
if hasattr(pymatgen_structure, "lattice"):
cell = cell or pymatgen_structure.lattice.matrix
pbc = pbc or pymatgen_structure.lattice.pbc
atoms = {}
if "numbers" in atom_include_keys:
atoms[key_mapping.get("numbers", "numbers")] = np.asarray(pymatgen_structure.atomic_numbers)
atom_include_keys = set(atom_include_keys) - {"numbers"}
for key in atom_include_keys:
get_attrs(atoms, pymatgen_structure.site_properties, key, key_mapping)
edges = {}
for key in edge_include_keys:
get_attrs(edges, pymatgen_structure.properties, key, key_mapping)
graph_globals = graph_globals or {}
for key in global_include_keys:
get_attrs(graph_globals, pymatgen_structure.properties, key, key_mapping)
return gcnn_graphs.graph_from_points(
pos=positions,
fractional_positions=False,
r_max=r_max,
cell=cell,
pbc=pbc,
nodes=atoms,
edges=edges,
graph_globals=graph_globals,
**kwargs,
)
# too slow: @jt.jaxtyped(typechecker=beartype.beartype)
[docs]
def graph_from_ase(
ase_atoms: "ase.atoms.Atoms",
r_max: numbers.Number,
*,
key_mapping: dict[str, str] | None = None,
atom_include_keys: Iterable | None = ("numbers",),
edge_include_keys: Iterable | None = tuple(),
global_include_keys: Iterable | None = tuple(),
cell: CellType | None = None,
pbc: bool | PbcType | None = None,
use_calculator: bool = True,
**kwargs,
) -> jraph.GraphsTuple:
"""Create a jraph Graph from an ase.Atoms object
Args:
ase_atoms: the Atoms object
r_max: the maximum neighbour distance to use when considering
two atoms to be neighbours
key_mapping
atom_include_keys
global_include_keys
cell: an optional unit cell (otherwise will be taken from
ase.cell)
pbc: an optional periodic boundary conditions array [bool, bool,
bool] (otherwise will be taken from ase.pbc)
use_calculator: if `True`, will try to use an attached
calculator get additional properties
Returns:
the atomic graph
"""
# pylint: disable=too-many-branches
from ase.calculators import singlepoint
import ase.stress
key_mapping = key_mapping or {}
_key_mapping = {
"forces": keys.FORCES,
"energy": keys.TOTAL_ENERGY,
"numbers": keys.ATOMIC_NUMBERS,
}
_key_mapping.update(key_mapping)
key_mapping = _key_mapping
del _key_mapping
graph_globals = {}
for key in global_include_keys:
get_attrs(graph_globals, ase_atoms.arrays, key, key_mapping)
atoms = {}
for key in atom_include_keys:
get_attrs(atoms, ase_atoms.arrays, key, key_mapping)
edges = {}
for key in edge_include_keys:
get_attrs(edges, ase_atoms.arrays, key, key_mapping)
if use_calculator and ase_atoms.calc is not None:
if not isinstance(
ase_atoms.calc,
(singlepoint.SinglePointCalculator, singlepoint.SinglePointDFTCalculator),
):
raise NotImplementedError(
f"`from_ase` does not support calculator {type(ase_atoms.calc).__name__}"
)
for key, val in ase_atoms.calc.results.items():
if key in atom_include_keys:
atoms[key] = base.atleast_1d(val, np_=np)
elif key in global_include_keys:
graph_globals[key] = base.atleast_1d(val, np_=np)
# Transform ASE-style 6 element Voigt order stress to Cartesian
for key in (keys.STRESS, keys.VIRIAL):
if key in graph_globals:
if graph_globals[key].shape == (3, 3):
# In the format we want
pass
elif graph_globals[key].shape == (6,):
# In Voigt order
graph_globals[key] = ase.stress.voigt_6_to_full_3x3_stress(graph_globals[key])
else:
raise RuntimeError(f"Unexpected shape for {key}, got: {graph_globals[key].shape}")
# cell and pbc in kwargs can override the ones stored in atoms
cell = cell or ase_atoms.get_cell()
pbc = pbc or ase_atoms.pbc
atom_graph = gcnn_graphs.graph_from_points(
pos=ase_atoms.positions,
fractional_positions=False,
r_max=r_max,
cell=cell.__array__() if pbc.any() else None,
pbc=pbc,
nodes=atoms,
edges=edges,
graph_globals=graph_globals,
**kwargs,
)
return atom_graph
def get_attrs(store_in: MutableMapping, get_from: Mapping, key: Hashable, key_map: Mapping) -> bool:
out_key = key_map.get(key, key)
try:
value = get_from[key]
except KeyError:
# Couldn't find the attribute
return False
store_in[out_key] = value
return True