# Source code for mlx_graphs.datasets.qm7b

import os
from typing import Optional

import mlx.core as mx
from tqdm import tqdm

from mlx_graphs.data import GraphData
from mlx_graphs.datasets.dataset import Dataset
from mlx_graphs.datasets.utils import check_sha1, download
from mlx_graphs.utils.transformations import to_sparse_adjacency_matrix


class QM7bDataset(Dataset):
    """
    QM7b dataset from the `"MoleculeNet: A Benchmark for Molecular Machine Learning"
    <https://arxiv.org/abs/1703.00564>`_ paper, consisting of 7,211 molecules
    with 14 regression targets.

    Each molecule is turned into one :class:`GraphData`. Its edges come from
    converting that molecule's entry of the raw ``"X"`` array (presumably its
    Coulomb matrix -- confirm against the MoleculeNet description) into a
    sparse adjacency, i.e. an edge index plus per-edge features. The
    molecule's entry of the raw ``"T"`` array is stored as its graph label.

    Args:
        base_dir: Directory where to store dataset files. Default is
            in the local directory ``.mlx_graphs_data/``.
    """

    _url = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/qm7b.mat"
    _sha1_str = "4102c744bb9d6fd7b40ac67a300e49cd87e28392"

    def __init__(self, base_dir: Optional[str] = None):
        super().__init__(name="qm7b", base_dir=base_dir)

    def download(self):
        """Download the raw ``qm7b.mat`` file into ``raw_path`` and verify it.

        Raises:
            UserWarning: If the SHA-1 digest of the downloaded file does not
                match ``_sha1_str``.
        """
        file_path = os.path.join(self.raw_path, self.name + ".mat")
        download(self._url, path=file_path)
        if not check_sha1(file_path, self._sha1_str):
            # ``UserWarning`` is kept as the raised type so existing callers
            # that catch it keep working.
            raise UserWarning(
                f"File {file_path} is downloaded but the content hash does not match. "
                "The repo may be outdated or download may be incomplete. "
                "Otherwise you can create an issue for it."
            )

    def process(self):
        """Load the raw ``.mat`` file and build one ``GraphData`` per molecule.

        The resulting list is stored in ``self.graphs``.

        Raises:
            ImportError: If SciPy is not installed.
            AssertionError: If ``self.raw_path`` is ``None``.
        """
        try:
            # Import the submodule explicitly: a bare ``import scipy`` does not
            # guarantee that ``scipy.io`` is loaded on older SciPy releases.
            from scipy.io import loadmat
        except ImportError as err:
            raise ImportError(
                "scipy is required to download and process the raw data"
            ) from err

        assert self.raw_path is not None, "Unable to access/create the self.raw_path"
        mat_path = os.path.join(self.raw_path, self.name + ".mat")
        data = loadmat(mat_path)

        # Labels ("T") are indexed per molecule along axis 0.
        labels = mx.array(data["T"].tolist())
        # Per-molecule matrices ("X"); row i is converted into molecule i's edges.
        features = mx.array(data["X"].tolist())
        num_graphs = labels.shape[0]

        graphs = []
        for i in tqdm(range(num_graphs)):
            edge_index, edge_features = to_sparse_adjacency_matrix(features[i])
            graphs.append(
                GraphData(
                    edge_index=edge_index,
                    edge_features=edge_features,
                    # Add a leading axis so each graph carries a batch-like
                    # label dimension of size 1.
                    graph_labels=mx.expand_dims(labels[i], 0),
                )
            )
        self.graphs = graphs