Source code for mlx_graphs.datasets.karate_club
import mlx.core as mx
from mlx_graphs.data.data import GraphData
from mlx_graphs.datasets.dataset import Dataset
from mlx_graphs.utils.transformations import to_undirected
[docs]
class KarateClubDataset(Dataset):
    """
    Zachary's Karate Club network dataset from `An Information Flow Model for\
    Conflict and Fission in Small Groups\
    <https://www.jstor.org/stable/3629752>`_. This is a simple
    dataset for node classification. The graph has 34 nodes and 78 undirected
    edges, which are stored in both directions (156 directed edges) in
    ``edge_index``. Each node belongs to one of 2 classes: label 0 is the
    faction containing node 0, label 1 is the faction containing node 33.
    Node features are a constant ``1`` per node (shape ``(34, 1)``).
    """

    # Each undirected edge is listed exactly once as (source, target) with
    # source < target, sorted by source then target. ``process`` adds the
    # reverse direction via ``to_undirected``. The order is kept stable so
    # the resulting ``edge_index`` columns are reproducible.
    _EDGES = (
        (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8),
        (0, 10), (0, 11), (0, 12), (0, 13), (0, 17), (0, 19), (0, 21), (0, 31),
        (1, 2), (1, 3), (1, 7), (1, 13), (1, 17), (1, 19), (1, 21), (1, 30),
        (2, 3), (2, 7), (2, 8), (2, 9), (2, 13), (2, 27), (2, 28), (2, 32),
        (3, 7), (3, 12), (3, 13),
        (4, 6), (4, 10),
        (5, 6), (5, 10), (5, 16),
        (6, 16),
        (8, 30), (8, 32), (8, 33),
        (9, 33),
        (13, 33),
        (14, 32), (14, 33),
        (15, 32), (15, 33),
        (18, 32), (18, 33),
        (19, 33),
        (20, 32), (20, 33),
        (22, 32), (22, 33),
        (23, 25), (23, 27), (23, 29), (23, 32), (23, 33),
        (24, 25), (24, 27), (24, 31),
        (25, 31),
        (26, 29), (26, 33),
        (27, 33),
        (28, 31), (28, 33),
        (29, 32), (29, 33),
        (30, 32), (30, 33),
        (31, 32), (31, 33),
        (32, 33),
    )

    # Class label of every node, indexed by node id (34 entries). Its length
    # is also the number of nodes in the graph.
    _NODE_LABELS = (
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1,  # nodes 0-9
        0, 0, 0, 0, 1, 1, 0, 0, 1, 0,  # nodes 10-19
        1, 0, 1, 1, 1, 1, 1, 1, 1, 1,  # nodes 20-29
        1, 1, 1, 1,                    # nodes 30-33
    )

    def __init__(self):
        # The data is embedded in this class, so no on-disk location is needed.
        super().__init__(name="karate_club", base_dir=None)

    def download(self):
        """Nothing to download: the dataset is defined inline in this class."""
        pass

    def process(self):
        """Build the single graph of the dataset and store it in ``self.graphs``.

        Produces one ``GraphData`` with:

        * ``edge_index``: the edges of ``_EDGES`` as a ``(2, num_edges)`` array,
          symmetrized with ``to_undirected``;
        * ``node_features``: an all-ones array of shape ``(num_nodes, 1)``;
        * ``node_labels``: an integer array of shape ``(num_nodes, 1)``.
        """
        num_nodes = len(self._NODE_LABELS)

        # (num_edges, 2) -> (2, num_edges), the layout to_undirected expects.
        edge_index = to_undirected(mx.array(list(self._EDGES)).transpose())

        # The dataset has no intrinsic node attributes, so use a constant feature.
        node_features = mx.ones([num_nodes, 1])

        # Wrap in an outer list to get shape (1, num_nodes), then transpose to
        # a (num_nodes, 1) column, the same layout as node_features.
        node_labels = mx.array([list(self._NODE_LABELS)]).transpose()

        self.graphs = [
            GraphData(
                edge_index=edge_index,
                node_features=node_features,
                node_labels=node_labels,
            )
        ]