# Source code for mlx_graphs.datasets.karate_club

import mlx.core as mx

from mlx_graphs.data.data import GraphData
from mlx_graphs.datasets.dataset import Dataset
from mlx_graphs.utils.transformations import to_undirected


class KarateClubDataset(Dataset):
    """Zachary's Karate Club network dataset from `An Information Flow Model
    for Conflict and Fission in Small Groups
    <https://www.jstor.org/stable/3629752>`_.

    This is a simple dataset for node classification. The single graph has
    34 nodes and 78 undirected edges (156 directed edges once every edge is
    stored in both directions). Each node has a single constant feature
    equal to 1 and belongs to one of 2 classes (label 0 or 1).
    """

    def __init__(self):
        super().__init__(name="karate_club", base_dir=None)

    def download(self):
        # Nothing to fetch: the graph data is embedded in ``process``.
        pass

    def process(self):
        """Build the karate club graph and store it as the only entry of
        ``self.graphs``.

        The resulting ``GraphData`` holds:

        * ``edge_index``: the 78 edge pairs below, transposed to shape
          (2, num_edges) and passed through ``to_undirected``.
        * ``node_features``: ones of shape (num_nodes, 1).
        * ``node_labels``: the per-node class labels as a (num_nodes, 1)
          column.
        """
        # Each undirected edge is listed exactly once as (i, j) with i < j,
        # grouped by source node; ``to_undirected`` is applied afterwards so
        # that (presumably) the reverse direction is added as well.
        edges = [
            (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8),
            (0, 10), (0, 11), (0, 12), (0, 13), (0, 17), (0, 19), (0, 21),
            (0, 31),
            (1, 2), (1, 3), (1, 7), (1, 13), (1, 17), (1, 19), (1, 21), (1, 30),
            (2, 3), (2, 7), (2, 8), (2, 9), (2, 13), (2, 27), (2, 28), (2, 32),
            (3, 7), (3, 12), (3, 13),
            (4, 6), (4, 10),
            (5, 6), (5, 10), (5, 16),
            (6, 16),
            (8, 30), (8, 32), (8, 33),
            (9, 33),
            (13, 33),
            (14, 32), (14, 33),
            (15, 32), (15, 33),
            (18, 32), (18, 33),
            (19, 33),
            (20, 32), (20, 33),
            (22, 32), (22, 33),
            (23, 25), (23, 27), (23, 29), (23, 32), (23, 33),
            (24, 25), (24, 27), (24, 31),
            (25, 31),
            (26, 29), (26, 33),
            (27, 33),
            (28, 31), (28, 33),
            (29, 32), (29, 33),
            (30, 32), (30, 33),
            (31, 32), (31, 33),
            (32, 33),
        ]

        # Class label (0 or 1) of every node, indexed by node id.
        labels = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 1,  # nodes 0-9
            0, 0, 0, 0, 1, 1, 0, 0, 1, 0,  # nodes 10-19
            1, 0, 1, 1, 1, 1, 1, 1, 1, 1,  # nodes 20-29
            1, 1, 1, 1,                    # nodes 30-33
        ]
        # Derive the node count from the label table instead of repeating
        # the literal 34.
        num_nodes = len(labels)

        # (num_edges, 2) list -> (2, num_edges) array before symmetrizing.
        edge_index = to_undirected(mx.array(edges).transpose())
        node_features = mx.ones([num_nodes, 1])
        # Wrap in an outer list and transpose: (1, num_nodes) -> (num_nodes, 1).
        node_labels = mx.array([labels]).transpose()

        self.graphs = [
            GraphData(
                edge_index=edge_index,
                node_features=node_features,
                node_labels=node_labels,
            )
        ]