Source code for mlx_graphs.datasets.dataset

import copy
import os
import pickle
from abc import ABC, abstractmethod
from typing import Callable, Literal, Optional, Sequence, Union

import mlx.core as mx
import numpy as np

from mlx_graphs.data import GraphData, HeteroGraphData

# Default location for downloaded datasets: a ``.mlx_graphs_data/`` folder
# inside the current working directory
DEFAULT_BASE_DIR = os.path.join(os.getcwd(), ".mlx_graphs_data/")


class Dataset(ABC):
    """
    Base dataset class. ``download`` and ``process`` methods must be
    implemented by child classes.

    The ``save`` and ``load`` methods save and load only the processed
    ``self.graphs`` attribute by default. You may want to override them to
    store/load additional processed attributes.

    Graph data within the dataset should be stored in ``self.graphs`` as a
    ``list[GraphData]``. The creation and preprocessing of this list of
    graphs is typically done within the overridden ``process`` method.

    Args:
        name: Name of the dataset.
        base_dir: Directory where to store dataset files. Default is the
            local directory ``.mlx_graphs_data/``.
        pre_transform: A function/transform that takes in a ``GraphData``
            object and returns a transformed version. The transformation is
            applied once to each graph after ``process``, before saving.
        transform: A function/transform that takes in a ``GraphData`` object
            and returns a transformed version. The transformation is applied
            before every access, i.e., during the ``__getitem__`` call. By
            default, no transformation is applied.
    """

    def __init__(
        self,
        name: str,
        base_dir: Optional[str] = None,
        pre_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
    ):
        self._name = name
        self._base_dir = base_dir if base_dir else DEFAULT_BASE_DIR
        self.transform = transform
        self.pre_transform = pre_transform
        self.graphs: Union[list[GraphData], list[HeteroGraphData]] = []

        self._load()

    @property
    def name(self) -> str:
        """Name of the dataset."""
        return self._name

    @property
    def raw_path(self) -> str:
        """The path where raw files are stored. Defaults to
        ``<base_dir>/<name>/raw``."""
        return os.path.expanduser(os.path.join(self._base_dir, self.name, "raw"))

    @property
    def processed_path(self) -> str:
        """The path where processed files are stored. Defaults to
        ``<base_dir>/<name>/processed``."""
        return os.path.expanduser(
            os.path.join(self._base_dir, self.name, "processed")
        )

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the dataset."""
        return len(self)

    @property
    def num_node_classes(self) -> Union[int, dict[str, int]]:
        """Returns the number of node classes to predict."""
        return self._num_classes("node")

    @property
    def num_edge_classes(self) -> Union[int, dict[str, int]]:
        """Returns the number of edge classes to predict."""
        return self._num_classes("edge")

    @property
    def num_graph_classes(self) -> Union[int, dict[str, int]]:
        """Returns the number of graph classes to predict."""
        return self._num_classes("graph")

    @property
    def num_node_features(self) -> int:
        """Returns the number of node features."""
        return self.graphs[0].num_node_features

    @property
    def num_edge_features(self) -> int:
        """Returns the number of edge features."""
        return self.graphs[0].num_edge_features

    @property
    def num_graph_features(self) -> int:
        """Returns the number of graph features."""
        return self.graphs[0].num_graph_features
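    # For orientation, with the defaults above a dataset named "cora"
    # (a hypothetical name, not shipped with this class) keeps its files at:
    #
    #     ds.raw_path        # ./.mlx_graphs_data/cora/raw
    #     ds.processed_path  # ./.mlx_graphs_data/cora/processed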
    @abstractmethod
    def download(self):
        """Download the dataset at ``self.raw_path``."""
        pass
    @abstractmethod
    def process(self):
        """Process the dataset and store graphs in ``self.graphs``."""
        pass
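    # A minimal sketch of how a subclass typically fills in these two
    # abstract methods (the ``download_file`` helper, ``SOME_URL``, and the
    # ``edges.npy`` layout are hypothetical, not part of this base class):
    #
    #     def download(self):
    #         download_file(SOME_URL, self.raw_path)  # hypothetical helper
    #
    #     def process(self):
    #         # load a [2, num_edges] array of edge indices from the raw dir
    #         edges = np.load(os.path.join(self.raw_path, "edges.npy"))
    #         self.graphs = [GraphData(edge_index=mx.array(edges))]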
    def save(self):
        """Save the processed dataset."""
        with open(os.path.join(self.processed_path, "graphs.pkl"), "wb") as f:
            pickle.dump(self.graphs, f)
    def load(self):
        """Load the processed dataset."""
        with open(os.path.join(self.processed_path, "graphs.pkl"), "rb") as f:
            self.graphs = pickle.load(f)
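    # Subclasses that compute extra attributes during ``process`` can persist
    # them by overriding both methods. A hedged sketch (``split_idx`` and the
    # ``split.pkl`` file name are hypothetical):
    #
    #     def save(self):
    #         super().save()
    #         with open(os.path.join(self.processed_path, "split.pkl"), "wb") as f:
    #             pickle.dump(self.split_idx, f)
    #
    #     def load(self):
    #         super().load()
    #         with open(os.path.join(self.processed_path, "split.pkl"), "rb") as f:
    #             self.split_idx = pickle.load(f)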
    def _download(self):
        if self._base_dir is not None and self.raw_path is not None:
            if os.path.exists(self.raw_path):
                return
            os.makedirs(self.raw_path, exist_ok=True)
            print(f"Downloading {self.name} raw data ...", end=" ")
            self.download()
            print("Done")

    def _process(self):
        self.process()
        if self.pre_transform:
            print(f"Applying pre-transform to {self.name} data ...", end=" ")
            self.graphs = [self.pre_transform(graph) for graph in self.graphs]
            print("Done")

    def _save(self):
        if self._base_dir is not None and self.processed_path is not None:
            if not os.path.exists(self.processed_path):
                os.makedirs(self.processed_path, exist_ok=True)
            print(f"Saving processed {self.name} data ...", end=" ")
            self.save()
            print("Done")

    def _load(self):
        # Try to load the already processed dataset; if unavailable, download
        # and process the raw data, then save the processed result.
        try:
            print(f"Loading {self.name} data ...", end=" ")
            self.load()
            print("Done")
        except FileNotFoundError:
            self._download()
            print(f"Processing {self.name} raw data ...", end=" ")
            self._process()
            print("Done")
            self._save()

    def _num_classes(
        self, task: Literal["node", "edge", "graph"]
    ) -> Union[int, dict[str, int]]:
        flattened_labels = []
        num_classes_dict = {}
        for g in self.graphs:
            if isinstance(g, GraphData):
                labels = getattr(g, f"{task}_labels")
                if labels is not None:
                    flattened_labels.append(labels)
            elif isinstance(g, HeteroGraphData):
                if task == "node":
                    labels_dict = g.node_labels_dict
                    if labels_dict is not None:
                        for node_type, labels in labels_dict.items():
                            if node_type not in num_classes_dict:
                                num_classes_dict[node_type] = []
                            num_classes_dict[node_type].append(labels)
                elif task == "edge":
                    labels_dict = g.edge_labels_dict
                    if labels_dict is not None:
                        for edge_type, labels in labels_dict.items():
                            if edge_type not in num_classes_dict:
                                num_classes_dict[edge_type] = []
                            num_classes_dict[edge_type].append(labels)
                else:  # task == "graph"
                    labels = g.graph_labels
                    if labels is not None:
                        if None not in num_classes_dict:
                            num_classes_dict[None] = []
                        num_classes_dict[None].append(labels)

        if len(flattened_labels) == 0 and len(num_classes_dict) == 0:
            return 0
        if len(flattened_labels) > 0:
            flattened_labels = np.concatenate(flattened_labels)
            return np.unique(flattened_labels).size
        if task == "node" or task == "edge":
            return {
                key: np.unique(np.concatenate(labels)).size
                for key, labels in num_classes_dict.items()
            }
        # task == "graph"
        graph_labels = num_classes_dict.get(None)
        if graph_labels is not None:
            return np.unique(np.concatenate(graph_labels)).size
        return 0

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.graphs)

    def __getitem__(
        self,
        idx: Union[int, np.integer, slice, mx.array, np.ndarray, Sequence],
    ) -> Union["Dataset", GraphData, HeteroGraphData]:
        """
        Returns graphs from the ``Dataset`` at given indices.

        If ``idx`` contains multiple indices (e.g. list or slice), then
        another ``Dataset`` object containing the corresponding indexed
        graphs is returned. If ``idx`` is a single index (e.g. int), then a
        single ``GraphData`` is returned.

        Args:
            idx: Indices or index of the graphs to gather from the dataset.

        Returns:
            A ``Dataset`` if ``idx`` contains multiple elements, or a
            ``GraphData`` otherwise.
""" indices = range(len(self)) if isinstance(idx, (int, np.integer)) or ( isinstance(idx, mx.array) and idx.ndim == 0 # type: ignore ): index = indices[idx] # type:ignore - idx here is a singleton data = self.graphs[index] if self.transform is not None: data = self.transform(data) return data if isinstance(idx, slice): indices = indices[idx] elif isinstance(idx, mx.array) and idx.dtype in [ # type: ignore mx.int64, mx.int32, mx.int16, ]: return self[idx.flatten().tolist()] # type: ignore - idx is a mx.array elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: return self[idx.flatten().tolist()] elif isinstance(idx, Sequence) and not isinstance(idx, str): indices = [indices[i] for i in idx] else: raise IndexError( f"Dataset indexing failed. Accepted indices are: int, mx.array, " f"list, tuple, np.ndarray (got '{type(idx).__name__}')" ) dataset = copy.copy(self) graphs = [self.graphs[i] for i in indices] if self.transform is not None: graphs = [self.transform(g) for g in graphs] dataset.graphs = graphs return dataset def __repr__(self): return ( self.name if len(self) is None else f"{self.name}(num_graphs={len(self)})" )