"""MNIST and CIFAR10 superpixel graph datasets (``mlx_graphs.datasets.superpixel``)."""

import os
import pickle
from typing import Literal, Optional, get_args

import mlx.core as mx
import numpy as np
from tqdm import tqdm

from mlx_graphs.data.data import GraphData
from mlx_graphs.datasets import Dataset
from mlx_graphs.datasets.utils import download, extract_archive
from mlx_graphs.utils import pairwise_distances

SUPERPIXEL_SPLITS = Literal["train", "test"]
SUPERPIXEL_NAMES = Literal["MNIST", "CIFAR10"]

# Prefix of the pickle files (``<prefix>_<split>.pkl``) shipped for each dataset.
SUPERPIXEL_PKL_FILES = {
    "MNIST": "mnist_75sp",
    "CIFAR10": "cifar10_150sp",
}

def sigma(distances: mx.array, k: int = 8) -> mx.array:
    """
    Computes the scale parameter sigma, defined as the averaged distance of each
    node to its k nearest neighbors, from a pairwise distance matrix.

    See Equation (47) in `Benchmarking Graph Neural Networks
    <https://arxiv.org/abs/2003.00982>`_ for more details.

    Args:
        distances: square array of pairwise distances between points
        k: nearest neighbors to consider

    Returns:
        A ``(num_nodes, 1)`` array with the value of sigma for each node. A small
        epsilon (1e-8) is added so the result can safely be used as a divisor.
    """
    num_nodes = distances.shape[0]

    if k >= num_nodes:
        # Not enough nodes to select k neighbors (``mx.partition`` requires
        # ``k < num_nodes``): fall back to a unit scale for every node.
        scale = mx.ones((num_nodes, 1))
    else:
        # The k + 1 smallest entries of each row: presumably the node itself
        # (distance 0, assuming a zero diagonal) plus its k nearest neighbors,
        # so the sum divided by k is the mean neighbor distance.
        knns = mx.partition(distances, k, axis=-1)[:, : k + 1]
        scale = knns.sum(axis=1, keepdims=True) / k

    return scale + 1e-8

def image_to_adjacency_matrix(
    coordinates: mx.array, features: mx.array, use_features: bool = True
) -> mx.array:
    """
    Computes a dense adjacency matrix from superpixel coordinates (and optionally
    features) as in Equation (47) of `Benchmarking Graph Neural Networks
    <https://arxiv.org/abs/2003.00982>`_. The k-NN sparsification happens
    afterwards, in ``adjacency_matrix_to_knn_edges``.

    Args:
        coordinates: coordinates of each pixel/node (reshaped to ``(-1, 2)``)
        features: features of each pixel/node
        use_features: whether to use features in computing the adjacency matrix.
            Defaults to True

    Returns:
        A ``(num_nodes, num_nodes)`` symmetric adjacency matrix with a zero
        diagonal.
    """
    coord = coordinates.reshape(-1, 2)
    coord_dist = pairwise_distances(coord, coord)

    # Gaussian kernel on distances; each row is scaled by that node's sigma.
    exponent = -((coord_dist / sigma(coord_dist)) ** 2)
    if use_features:
        features_dist = pairwise_distances(features, features)
        exponent = exponent - (features_dist / sigma(features_dist)) ** 2
    adjacency_matrix = mx.exp(exponent)

    # Per-row sigmas make the kernel asymmetric; average it with its transpose.
    adjacency_matrix = 0.5 * (adjacency_matrix + adjacency_matrix.T)

    # Remove self-loops by zeroing the diagonal in a single vectorized op.
    node_ids = mx.arange(adjacency_matrix.shape[0])
    on_diagonal = node_ids[:, None] == node_ids[None, :]
    return mx.where(on_diagonal, 0.0, adjacency_matrix)

def adjacency_matrix_to_knn_edges(
    adjacency_matrix: mx.array, k: int = 9
) -> tuple[mx.array, mx.array]:
    """
    Compute the nearest-neighbor nodes of each node, and the matching adjacency
    values, from a dense adjacency matrix.

    Args:
        adjacency_matrix: square ``(num_nodes, num_nodes)`` adjacency matrix
        k: the number of nearest neighbors

    Returns:
        A tuple ``(knns, knn_values)`` of ``(num_nodes, m)`` arrays holding, for
        each node, its neighbor indices (int32) and the corresponding adjacency
        values. When ``num_nodes > k``, ``m == k - 1``. Otherwise the graph is
        fully connected: ``m == num_nodes - 1`` with self-loops removed, except
        for a single-node graph, which keeps its self-loop.
    """
    num_nodes = adjacency_matrix.shape[0]
    # ``np.asarray`` instead of ``np.array(..., copy=False)``: under NumPy 2 the
    # latter means "never copy" and raises whenever a copy is required.
    np_adj_mat = np.asarray(adjacency_matrix)

    if num_nodes > k:
        new_kth = num_nodes - k
        # We need numpy's argpartition because, when there are equal elements in a
        # partition, their order is random with mlx, while with numpy they are
        # ordered based on their index in the original array. This turns out to
        # be a significant problem with highly connected graphs.
        # NOTE(review): columns ``new_kth:`` hold the k largest entries of each
        # row in unspecified order; slicing up to ``-1`` keeps k - 1 of them and
        # the dropped one is not guaranteed to be the largest. Kept as-is for
        # reproducibility of the generated graphs -- confirm before changing.
        knns = np.argpartition(np_adj_mat, new_kth - 1, axis=-1)[:, new_kth:-1]
        # Gather the values through the indices so both outputs always align.
        knn_values = np.take_along_axis(np_adj_mat, knns, axis=-1)
    else:
        # For graphs with at most k nodes the resulting graph is fully connected:
        # every row lists all nodes 0..num_nodes-1 (tile, not repeat -- repeat
        # would fill row i with i only, i.e. nothing but self-loops).
        knns = np.tile(np.arange(num_nodes), (num_nodes, 1))
        knn_values = np_adj_mat

        # remove self loops (a single-node graph keeps its only entry)
        if num_nodes > 1:
            not_self = knns != np.arange(num_nodes)[:, None]
            knn_values = knn_values[not_self].reshape(num_nodes, num_nodes - 1)
            knns = knns[not_self].reshape(num_nodes, num_nodes - 1)

    return (
        mx.array(np.ascontiguousarray(knns, dtype=np.int32)),
        mx.array(np.ascontiguousarray(knn_values)),
    )

class SuperPixelDataset(Dataset):
    """
    MNIST and CIFAR10 superpixel datasets for graph classification tasks, converted
    from the original MNIST and CIFAR10 images.

    The datasets are described in `Benchmarking Graph Neural Networks
    <https://arxiv.org/abs/2003.00982>`_.

    Each image becomes one graph. Node features are the superpixel mean pixel
    value(s) concatenated with the superpixel coordinates normalized by the image
    size. Edges link each superpixel to its nearest neighbors in the adjacency
    matrix computed by ``image_to_adjacency_matrix``.

    Args:
        name: name of the selected dataset ("MNIST" or "CIFAR10")
        split: split of the dataset to load ("train" or "test")
        use_features: if True, the adjacency matrix is computed from superpixels
            locations and features. If False, only from superpixels locations.
            Defaults to False.
        base_dir: directory where to store the datasets
    """

    # NOTE(review): the download URL is empty in this copy; ``download`` cannot
    # fetch anything until it is restored.
    _url = ""

    def __init__(
        self,
        name: SUPERPIXEL_NAMES,
        split: SUPERPIXEL_SPLITS,
        use_features: bool = False,
        base_dir: Optional[str] = None,
    ):
        assert name in get_args(SUPERPIXEL_NAMES), "Invalid dataset name"
        assert split in get_args(SUPERPIXEL_SPLITS), "Invalid split specified"

        # Set before ``super().__init__`` -- presumably the base class triggers
        # download/processing, which reads these attributes (TODO confirm).
        self.split = split
        self.use_features = use_features
        super().__init__(name=name, base_dir=base_dir)

    @property
    def _img_size(self) -> int:
        """Size of the source images, used to normalize superpixel coordinates."""
        return 28 if self.name == "MNIST" else 32

    @property
    def processed_path(self) -> str:
        # Processed data is cached separately per split and use_features setting.
        # Zero-argument ``super()`` (unlike ``super(self.__class__, self)``) does
        # not recurse forever when this class is subclassed.
        return os.path.join(
            f"{super().processed_path}",
            self.split,
            "use_features_" + str(self.use_features),
        )

    def download(self):
        # NOTE(review): the archive file name is empty in this copy, so the target
        # path is ``raw_path`` itself -- confirm the intended file name.
        file_path = os.path.join(self.raw_path, "")
        # ``download`` below resolves to the module-level helper, not this method.
        path = download(self._url, path=file_path)
        extract_archive(path, self.raw_path, overwrite=True)

    def process(self):
        pkl_file = os.path.join(
            self.raw_path,
            "superpixels",
            f"{SUPERPIXEL_PKL_FILES[self.name]}_{self.split}.pkl",
        )
        # NOTE(review): ``pickle.load`` can execute arbitrary code; only load
        # archives obtained from a trusted source.
        with open(pkl_file, "rb") as f:
            labels, data = pickle.load(f)

        # One row per graph, so that ``labels[idx]`` is a length-1 array.
        labels = mx.array([labels.tolist()]).T

        for idx, sample in enumerate(
            tqdm(data, desc=f"Processing {self.name} {self.split} dataset")
        ):
            mean_px, coord = sample[:2]
            mean_px = mx.array(mean_px.tolist())
            coord = mx.array(coord.tolist()) / self._img_size

            # super-pixel locations, plus features when ``use_features`` is set
            adjacency_matrix = image_to_adjacency_matrix(
                coord, mean_px, bool(self.use_features)
            )
            edges_list, edges_values = adjacency_matrix_to_knn_edges(adjacency_matrix)

            num_nodes = adjacency_matrix.shape[0]
            mean_px = mean_px.reshape(num_nodes, -1)
            coord = coord.reshape(num_nodes, 2)
            node_features = mx.concatenate((mean_px, coord), axis=1)

            edge_index, edge_features = self._knn_to_edges(edges_list, edges_values)

            self.graphs.append(
                GraphData(
                    edge_index=edge_index,
                    node_features=node_features,
                    edge_features=edge_features,
                    graph_labels=labels[idx],
                )
            )

    @staticmethod
    def _knn_to_edges(
        edges_list: mx.array, edges_values: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Flatten per-node neighbor lists into ``(edge_index, edge_features)``.

        Row ``i`` of ``edges_list`` holds the neighbors of source node ``i`` and
        row ``i`` of ``edges_values`` the matching adjacency values. Self-loops
        are dropped together with their values, so both outputs stay aligned,
        unless the graph has a single node. ``edge_index`` is an int32
        ``(2, num_edges)`` array and ``edge_features`` a ``(num_edges, 1)`` array.
        """
        knns = np.asarray(edges_list)
        weights = np.asarray(edges_values)
        num_nodes = knns.shape[0]

        # Row-major flattening: edges are ordered by source, then by row order.
        src = np.repeat(np.arange(num_nodes), knns.shape[1])
        dst = knns.reshape(-1)
        weights = weights.reshape(-1)
        if num_nodes != 1:
            keep = src != dst
            src, dst, weights = src[keep], dst[keep], weights[keep]

        edge_index = mx.array(np.stack([src, dst]).astype(np.int32))
        edge_features = mx.array(np.ascontiguousarray(weights[:, None]))
        return edge_index, edge_features