Source code for mlx_graphs.datasets.ogb_dataset

import os
from typing import Literal, Optional, Union, get_args, overload

import mlx.core as mx
import numpy as np

from mlx_graphs.data.data import GraphData
from mlx_graphs.datasets import Dataset
from mlx_graphs.utils import index_to_mask

OGB_NODE_DATASET = Literal[
    "ogbn-products",
    "ogbn-proteins",
    "ogbn-arxiv",
    "ogbn-papers100M",
    # "ogbn-mag", # TODO: requires heterogeneous graphs
]
OGB_EDGE_DATASET = Literal[
    "ogbl-ppa",
    "ogbl-collab",
    "ogbl-ddi",
    "ogbl-citation2",
    # "ogbl-wikikg2",  # TODO: requires heterogeneous graphs
    # "ogbl-biokg",  # TODO: requires heterogeneous graphs
    "ogbl-vessel",
]
OGB_GRAPH_DATASET = Literal[
    "ogbg-molhiv",
    "ogbg-molpcba",
    "ogbg-ppa",
    "ogbg-code2",
]

ALL_DATASETS = Literal[
    "ogbn-products",
    "ogbn-proteins",
    "ogbn-arxiv",
    "ogbn-papers100M",
    # "ogbn-mag", # TODO: requires heterogeneous graphs
    "ogbl-ppa",
    "ogbl-collab",
    "ogbl-ddi",
    "ogbl-citation2",
    # "ogbl-wikikg2",  # TODO: requires heterogeneous graphs
    # "ogbl-biokg",  # TODO: requires heterogeneous graphs
    "ogbl-vessel",
    "ogbg-molhiv",
    "ogbg-molpcba",
    "ogbg-ppa",
    "ogbg-code2",
]


@overload
def to_mx_array(x: None) -> None:
    ...


@overload
def to_mx_array(x: np.ndarray) -> mx.array:
    ...


def to_mx_array(x: Union[np.ndarray, None]) -> Union[mx.array, None]:
    if x is None:
        return None
    else:
        return mx.array(x.tolist())
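

# Illustrative doctest-style sketch (added for clarity, not part of the library
# source): `to_mx_array` round-trips a numpy array through a Python list into an
# `mx.array`, preserving the shape, while `None` passes through unchanged.
#
#   arr = to_mx_array(np.arange(6).reshape(2, 3))  # -> mx.array of shape (2, 3)
#   assert to_mx_array(None) is None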


class OGBDataset(Dataset):
    """
    Datasets from the `Open Graph Benchmark (OGB) <https://ogb.stanford.edu>`_
    collection of realistic, large-scale, and diverse benchmark datasets for
    machine learning on graphs.

    Datasets belong to three fundamental graph machine learning task
    categories: predicting the properties of nodes, links, and graphs.

    Node property prediction datasets consist of a single graph with three
    additional properties: `train_mask`, `val_mask` and `test_mask`,
    specifying the masks for the train, validation and test splits.

    Link property prediction datasets also consist of a single graph with
    three additional properties: `train_edge_index`, `val_edge_index` and
    `test_edge_index`, specifying the edges to be considered for training,
    validation and testing.

    Graph property prediction datasets consist of multiple graphs. The
    desired split can be specified via the `split` arg.

    See `here <https://ogb.stanford.edu/docs/dataset_overview/>`_ for further
    details and a list of the available datasets with their descriptions.

    Args:
        name: Name of the dataset.
        split: Split of the dataset to load. This parameter only has an
            effect on graph property prediction datasets. If `None`, the
            entire dataset is loaded. Defaults to `None`.
        base_dir: Directory where to store dataset files. Default is in the
            local directory ``.mlx_graphs_data/``.

    .. note::
        `ogb` needs to be installed to use this dataset.

    .. note::
        The ogbn-mag, ogbl-wikikg2 and ogbl-biokg datasets, as well as the
        graphs belonging to the large-scale challenge category, are currently
        not available as they require heterogeneous graphs, which are not yet
        supported by `mlx-graphs`.

    Example:

    .. code-block:: python

        from mlx_graphs.datasets.ogb_dataset import OGBDataset

        ds = OGBDataset("ogbg-molhiv", split="train")
        >>> ogbg-molhiv(num_graphs=32901)
    """

    def __init__(
        self,
        name: ALL_DATASETS,
        split: Optional[Literal["train", "val", "test"]] = None,
        base_dir: Optional[str] = None,
    ):
        self.split = split
        self._raw_ogb_dataset = None
        super().__init__(name, base_dir)

    @property
    def processed_path(self) -> str:
        # The processed path includes the split for graph property prediction
        # datasets, since each split is stored separately.
        if self.name in get_args(OGB_GRAPH_DATASET):
            return os.path.join(
                f"{super(self.__class__, self).processed_path}",
                self.split if self.split is not None else "full",
            )
        else:
            return super(self.__class__, self).processed_path

    def download(self):
        self._load_data()

    def _load_data(self):
        try:
            if self.name in get_args(OGB_NODE_DATASET):
                from ogb.nodeproppred import NodePropPredDataset as OGB_dataset
            elif self.name in get_args(OGB_EDGE_DATASET):
                from ogb.linkproppred import LinkPropPredDataset as OGB_dataset
            elif self.name in get_args(OGB_GRAPH_DATASET):
                from ogb.graphproppred import GraphPropPredDataset as OGB_dataset
            self._raw_ogb_dataset = OGB_dataset(name=self.name, root=self.raw_path)
        except ImportError:
            raise ImportError(
                "ogb needs to be installed to use this dataset",
                "you can install it via pip install ogb",
            )

    def process(self):
        if self._raw_ogb_dataset is not None:
            dataset = self._raw_ogb_dataset
        else:
            # Reload if the dataset was already downloaded in a previous run.
            self._load_data()
            dataset = self._raw_ogb_dataset

        if self.name in get_args(OGB_NODE_DATASET):
            graph, label = dataset[0]  # type: ignore - it's NodePropPredDataset
            split_idx: dict = dataset.get_idx_split()  # type: ignore
            train_idx, val_idx, test_idx = (
                split_idx["train"],
                split_idx["valid"],
                split_idx["test"],
            )
            num_nodes = graph["num_nodes"]
            self.graphs.append(
                GraphData(
                    edge_index=to_mx_array(graph["edge_index"]),
                    node_features=to_mx_array(graph["node_feat"]),
                    edge_features=to_mx_array(graph["edge_feat"]),
                    node_labels=to_mx_array(label),  # type: ignore
                    train_mask=index_to_mask(to_mx_array(train_idx), size=num_nodes),
                    val_mask=index_to_mask(to_mx_array(val_idx), size=num_nodes),
                    test_mask=index_to_mask(to_mx_array(test_idx), size=num_nodes),
                )
            )
        elif self.name in get_args(OGB_EDGE_DATASET):
            graph = dataset[0]  # type: ignore - it's LinkPropPredDataset
            split_idx: dict = dataset.get_edge_split()  # type: ignore
            train_edges, val_edges, test_edges = (
                split_idx["train"]["edge"].reshape(2, -1),
                split_idx["valid"]["edge"].reshape(2, -1),
                split_idx["test"]["edge"].reshape(2, -1),
            )
            self.graphs.append(
                GraphData(
                    edge_index=to_mx_array(graph["edge_index"]),  # type: ignore
                    node_features=to_mx_array(graph["node_feat"]),  # type: ignore
                    edge_features=to_mx_array(graph["edge_feat"]),  # type: ignore
                    train_edge_index=to_mx_array(train_edges),
                    val_edge_index=to_mx_array(val_edges),
                    test_edge_index=to_mx_array(test_edges),
                )
            )
        elif self.name in get_args(OGB_GRAPH_DATASET):
            for graph, label in dataset:  # type: ignore - it's GraphPropPredDataset
                self.graphs.append(
                    GraphData(
                        edge_index=to_mx_array(graph["edge_index"]),
                        node_features=to_mx_array(graph["node_feat"]),
                        edge_features=to_mx_array(graph["edge_feat"]),
                        node_labels=to_mx_array(label),  # type: ignore
                    )
                )
            split_idx: dict = dataset.get_idx_split()  # type: ignore
            train_idx, val_idx, test_idx = (
                split_idx["train"],
                split_idx["valid"],
                split_idx["test"],
            )
            if self.split == "train":
                self.graphs = [self.graphs[i] for i in train_idx]
            elif self.split == "val":
                self.graphs = [self.graphs[i] for i in val_idx]
            elif self.split == "test":
                self.graphs = [self.graphs[i] for i in test_idx]
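

# Usage sketch (illustrative only, not part of the module): the three OGB task
# families expose their splits differently, as described in the class docstring.
# The `.graphs` attribute used below is the list populated by `process()` above;
# downloading any of these datasets requires the `ogb` package to be installed.
#
#   # Node property prediction: a single graph with train/val/test node masks.
#   arxiv = OGBDataset("ogbn-arxiv")
#   g = arxiv.graphs[0]
#   num_train_nodes = g.train_mask.sum()
#
#   # Link property prediction: a single graph with per-split edge indices.
#   collab = OGBDataset("ogbl-collab")
#   val_edges = collab.graphs[0].val_edge_index
#
#   # Graph property prediction: many graphs; the split is chosen at load time.
#   molhiv_train = OGBDataset("ogbg-molhiv", split="train")
#   first_graph = molhiv_train.graphs[0]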