import collections.abc
import io
import logging
import os
import pathlib
import tarfile
import types
from typing import Any, Final, TypedDict
import urllib.request
import ase.data
import jraph
import numpy as np
import tqdm
from .. import base, gcnn
# Public API of this module: only the dataset class is exported.
__all__ = ("Qm9",)
# A parsed molecule: maps field names (e.g. "positions", "species" and the
# labels listed in QM9_XYZ_LABELS) to their values.
MoleculeDict = dict[str, Any]
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Names of the whitespace-separated fields found on the second line of a QM9
# ``.xyz`` file, in the order in which they appear. ``read_qm9`` pairs these
# labels positionally with the parsed values: the first value is kept as a
# string, the second is converted to ``int`` and every later one to ``float``.
QM9_XYZ_LABELS: Final[list[str]] = [
    # Kept as ``str`` / converted to ``int``
    "tag", "index",
    # Everything below is converted to ``float``
    "A", "B", "C",
    "mu", "alpha",
    "homo", "lumo", "gap",
    "r2", "zpve",
    "U0", "U", "H", "G",
    "Cv",
]
class GraphOptions(TypedDict):
r_max: float
self_edges: bool
node_attrs: list[str | tuple[str, str]]
graph_attrs: list[str | tuple[str, str]]
np_: types.ModuleType
# NOTE: a stray "[docs]" marker (residue of the rendered documentation page) was removed here.
class Qm9(collections.abc.Sequence):
    """Sequence-style access to the molecules stored in the QM9 ``.xyz`` archive.

    On construction the archive is optionally downloaded into ``data_dir`` and
    the selected members are parsed with ``read_qm9`` into dictionaries. If
    ``as_graphs`` options are supplied, an entry is converted into a
    ``jraph.GraphsTuple`` the first time it is indexed and the converted graph
    replaces the dictionary in the internal list.
    """

    # Base download location; the archive file name is appended to it with a "/".
    URL: Final[str] = "https://springernature.figshare.com/ndownloader/files/3195389"
    # Name of the archive, both remotely and inside ``data_dir``.
    FILENAME: Final[str] = "dsgdb9nsd.xyz.tar.bz2"
    # Not referenced anywhere inside this class.
    QM9_STRUCTURES: Final[str] = "qm9_structures"

    def __init__(
        self,
        data_dir: str = "data/",
        download: bool = True,
        limit: int | None = None,
        as_graphs: dict | None = None,
    ):
        """Load the dataset from ``data_dir``.

        :param data_dir: directory that holds (or will receive) the archive.
        :param download: if true, fetch the archive unless a file with the
            archive name already exists in ``data_dir``.
        :param limit: read at most this many molecules; ``None`` reads all of them.
        :param as_graphs: keyword arguments for the module-level ``to_graph``
            function. When given (and non-empty), indexing returns graphs
            instead of molecule dictionaries.
        """
        # Params
        self._data_dir: Final[str] = data_dir
        self._download: Final[bool] = download
        self._to_graphs: Final[dict | None] = as_graphs

        # State
        if download:
            self._do_download("/".join([self.URL, self.FILENAME]), self.FILENAME)

        archive_path = pathlib.Path(self._data_dir) / self.FILENAME
        self._data = self._extract_tarball(archive_path, limit)

    def __getitem__(self, item):
        """Return the entry at ``item``; a slice returns a list of entries.

        With graph options set, entries are converted lazily on first access
        and the result is cached in place of the molecule dictionary.
        """
        if isinstance(item, slice):
            # Resolve slices element by element so that every entry goes through
            # the same lazy conversion (and caching) as integer indexing.
            return [self[index] for index in range(*item.indices(len(self._data)))]

        entry = self._data[item]
        if self._to_graphs and not isinstance(entry, jraph.GraphsTuple):
            # Lazily convert the first time
            entry = self.to_graph(entry)
            self._data[item] = entry
        return entry

    def __len__(self) -> int:
        """Return the number of loaded molecules."""
        return len(self._data)

    def _do_download(self, url: str, filename: str):
        """Download the file at the URL to our data dir.

        Nothing is downloaded if ``filename`` already exists in the data dir.

        :raises ConnectionRefusedError: if the response carries an AWS WAF
            ``challenge`` action header.
        :raises ValueError: if the server delivered no data.
        """
        # exist_ok avoids the check-then-create race of a separate existence test
        os.makedirs(self._data_dir, exist_ok=True)

        out_file = os.path.join(self._data_dir, filename)
        if os.path.isfile(out_file):
            return

        local_file, headers = urllib.request.urlretrieve(url, out_file)  # nosec

        # 1. Check for AWS WAF Challenge specifically
        waf_action = headers.get("x-amzn-waf-action")
        if waf_action == "challenge":
            os.remove(local_file)  # Clean up the empty/useless file
            raise ConnectionRefusedError(
                f"Blocked by AWS WAF: {waf_action}. The server requires a browser challenge. "
                f"Could not download {filename}, please try download it in your browser and "
                f"saving to {out_file}."
            )

        # 2. Check for an empty download. The header may be absent (e.g. chunked
        # responses), so the size of the written file is checked as well.
        content_length = int(headers.get("Content-Length", -1))
        if content_length == 0 or os.path.getsize(local_file) == 0:
            os.remove(local_file)
            raise ValueError(
                "Download failed: Server returned 0 bytes of data. "
                f"Could not download {filename}, please try download it in your browser and "
                f"saving to {out_file}."
            )

        _LOGGER.info("downloaded %s to %s", url, self._data_dir)

    def _extract_tarball(self, archive_path, limit=None) -> list[MoleculeDict]:
        """Parse the molecules stored in the archive at ``archive_path``.

        :param archive_path: path of the tar archive.
        :param limit: maximum number of molecules to read; ``None`` reads all
            of them and ``0`` reads none.
        :return: one dictionary per molecule as produced by ``read_qm9``, with
            the archive member name added under ``"filename"``.
        """
        molecules = []
        with tarfile.open(archive_path) as archive:
            # ``extractfile`` yields no stream for directories and other
            # non-file members, so only regular files are considered.
            members = [member for member in archive.getmembers() if member.isfile()]
            if limit is not None:
                members = members[:limit]

            for member in tqdm.tqdm(members):
                # Closing the text wrapper also closes the extracted member stream
                with io.TextIOWrapper(archive.extractfile(member), encoding="utf-8") as text:
                    molecule = read_qm9(text)
                molecule["filename"] = member.name
                molecules.append(molecule)

        return molecules

    def to_graph(self, entry: MoleculeDict) -> jraph.GraphsTuple:
        """Convert a molecule dictionary to a graph using the ``as_graphs`` options.

        :raises TypeError: if no ``as_graphs`` options were passed to the
            constructor (``None`` cannot be unpacked as keyword arguments).
        """
        return to_graph(entry, **self._to_graphs)
def _do_extract(archive_path, entry) -> MoleculeDict:
    """Parse a single member of the tar archive at ``archive_path``.

    The archive is opened for this one call, the member named ``entry.name`` is
    read as UTF-8 text and handed to ``read_qm9``.

    :param archive_path: path of the tar archive.
    :param entry: object whose ``name`` attribute identifies the member to read.
    :return: the parsed molecule with the member name stored under ``"filename"``.
    """
    with tarfile.open(archive_path) as archive:
        member_name = entry.name
        raw_stream = archive.extractfile(member_name)
        text_stream = io.TextIOWrapper(raw_stream, encoding="utf-8")
        molecule = read_qm9(text_stream)
        molecule["filename"] = member_name
    return molecule
def read_qm9(file_handle) -> MoleculeDict:
    """Parse one molecule from a QM9 ``.xyz`` stream.

    Format description can be found here:
    https://springernature.figshare.com/articles/dataset/Readme_file_Data_description_for_Quantum_chemistry_structures_and_properties_of_134_kilo_molecules_/1057641?backTo=%2Fcollections%2FQuantum_chemistry_structures_and_properties_of_134_kilo_molecules%2F978904&file=3195392

    The first line holds the number of atoms and the second line the
    whitespace-separated molecule properties, which are labelled with
    ``QM9_XYZ_LABELS``. Each of the following ``num_atoms`` lines describes one
    atom: the element symbol followed by numeric columns, of which the first
    three are used as the position. Any further lines are ignored.

    :param file_handle: readable text or binary stream; binary content is
        decoded as UTF-8. The stream is left open.
    :return: dictionary with ``"positions"`` (float array of shape
        ``(num_atoms, 3)``), ``"species"`` (list of element symbols) and one
        item per parsed property.
    :raises ValueError: if a number cannot be parsed or an atom line has fewer
        than three coordinates.
    """
    # Decode the lines directly instead of wrapping the stream: this works for
    # every kind of binary stream and avoids a temporary text wrapper closing a
    # caller-owned stream when the wrapper is discarded.
    lines = [
        line.decode("utf-8") if isinstance(line, bytes) else line
        for line in file_handle.readlines()
    ]

    num_atoms = int(lines[0].strip())
    properties = lines[1].split()  # Contains properties like energy, dipole moment, etc.
    atoms = [line.split(maxsplit=1) for line in lines[2 : 2 + num_atoms]]

    # Parse atomic symbols and positions
    species = [atom[0] for atom in atoms]
    positions = []
    for atom in atoms:
        # Some values use "*^" as the exponent marker, which float() cannot read
        columns = atom[1].replace("*^", "E").split()[:3]
        if len(columns) != 3:
            raise ValueError(f"Expected 3 coordinates for atom '{atom[0]}', got {len(columns)}")
        positions.append([float(value) for value in columns])
    # The explicit reshape keeps the (n, 3) shape even when there are no atoms
    coords = np.array(positions, dtype=float).reshape(len(atoms), 3)

    entries = {
        "positions": coords,
        "species": species,
    }

    # Now add the properties: the first stays a string, the second is an
    # integer and all others are floats.
    for i, (label, prop) in enumerate(zip(QM9_XYZ_LABELS, properties)):
        if i == 1:
            prop = int(prop)
        elif i > 1:
            prop = float(prop.replace("*^", "E"))
        entries[label] = prop

    return entries
def _split_attr_key(key, kind: str) -> tuple[str, str]:
    """Return ``(entry_key, label)`` for one attribute specification.

    A plain string is used both as the key into the entry and as the label; a
    2-tuple is interpreted as ``(entry_key, label)``. ``kind`` only names the
    attribute family ("Node" or "Graph") in the error message.
    """
    if isinstance(key, str):
        return key, key
    if isinstance(key, tuple):
        entry_key, label = key
        return entry_key, label
    raise ValueError(f"{kind} attributes key must be str or tuple, got {type(key)}")


def to_graph(
    entry: MoleculeDict,
    r_max: float,
    self_edges: bool = False,
    node_attrs: list[str | tuple[str, str]] | None = None,
    graph_attrs: list[str | tuple[str, str]] | None = None,
    np_=np,
) -> jraph.GraphsTuple:
    """Build a graph from a molecule dictionary via ``gcnn.graph_from_points``.

    :param entry: molecule dictionary; must contain ``"species"`` and
        ``"positions"`` plus every key named in ``node_attrs``/``graph_attrs``.
    :param r_max: forwarded positionally to ``gcnn.graph_from_points``.
    :param self_edges: forwarded as ``strict_self_interaction``.
    :param node_attrs: entry keys to store on the nodes. An item is either a
        key (also used as the label) or a ``(key, label)`` tuple.
    :param graph_attrs: entry keys to store as graph globals, same item format
        as ``node_attrs``.
    :param np_: array module handed to ``base.atleast_1d``.
    :return: whatever ``gcnn.graph_from_points`` returns for these inputs.
    :raises ValueError: if an attribute specification is neither ``str`` nor ``tuple``.
    """
    n_nodes = len(entry["species"])

    # Convert species labels to numbers
    atomic_numbers = np.fromiter(
        (ase.data.atomic_numbers[symbol] for symbol in entry["species"]), dtype=float
    )

    node_attrs_ = {gcnn.atomic.ATOMIC_NUMBERS: atomic_numbers}
    for spec in node_attrs or ():
        key, label = _split_attr_key(spec, "Node")
        attr = base.atleast_1d(entry[key], np_=np_)
        if attr.shape[0] != n_nodes:
            # NOTE(review): this reshape can only succeed when the attribute
            # holds exactly n_nodes * attr.shape[0] values; confirm whether a
            # per-node broadcast was intended here.
            attr = attr.reshape(n_nodes, attr.shape[0])
        node_attrs_[label] = attr

    graph_attrs_ = {}
    for spec in graph_attrs or ():
        key, label = _split_attr_key(spec, "Graph")
        graph_attrs_[label] = base.atleast_1d(entry[key], np_=np_)

    # NOTE(review): ``self_edges`` is passed as ``strict_self_interaction`` while
    # ``self_interaction`` is always False; confirm against
    # ``gcnn.graph_from_points`` that this mapping is the intended one.
    return gcnn.graph_from_points(
        entry["positions"],
        r_max,
        cell=None,
        fractional_positions=False,
        self_interaction=False,
        strict_self_interaction=self_edges,
        pbc=None,
        nodes=node_attrs_,
        graph_globals=graph_attrs_,
    )