Node2vec#
[!WARNING] This tutorial is experimental and requires mlx_cluster to be installed, which currently requires mlx 0.18.
Goal: This tutorial will guide you through implementing node2vec to generate vector embeddings for nodes in a simple undirected graph.
Concepts: MLX, Node2vec
[1]:
from collections import defaultdict
import mlx.core as mx
from mlx_graphs.datasets import PlanetoidDataset
Dataset#
For this first tutorial, we will use the PlanetoidDataset collection, which comprises the citation networks Cora, Pubmed and CiteSeer.
We will be using the Cora dataset, which consists of 2,708 nodes and 10,556 edges. The dataset can easily be accessed via the PlanetoidDataset class.
[2]:
dataset = PlanetoidDataset("Cora")
Loading cora data ... Done
[3]:
dataset
[3]:
cora(num_graphs=1)
We can access dataset properties directly from the dataset object.
[4]:
# Some useful properties
print("Dataset attributes")
print("-" * 20)
print(f"Number of graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of edge features: {dataset.num_edge_features}")
print(f"Number of graph features: {dataset.num_graph_features}")
print(f"Number of graph classes to predict: {dataset.num_graph_classes}\n")

# Statistics of the dataset
stats = defaultdict(list)
for g in dataset:
    stats["Mean node degree"].append(g.num_edges / g.num_nodes)
    stats["Mean num of nodes"].append(g.num_nodes)
    stats["Mean num of edges"].append(g.num_edges)

print("Dataset stats")
print("-" * 20)
for k, v in stats.items():
    mean = mx.mean(mx.array(v)).item()
    print(f"{k}: {mean:.2f}")
Dataset attributes
--------------------
Number of graphs: 1
Number of node features: 1433
Number of edge features: 0
Number of graph features: 0
Number of graph classes to predict: 0
Dataset stats
--------------------
Mean node degree: 3.90
Mean num of nodes: 2708.00
Mean num of edges: 10556.00
[5]:
dataset[0]
[5]:
GraphData(
edge_index(shape=(2, 10556), int32)
node_features(shape=(2708, 1433), float32)
node_labels(shape=(2708,), int32)
train_mask(shape=(2708,), bool)
val_mask(shape=(2708,), bool)
test_mask(shape=(2708,), bool))
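Individual graphs expose these attributes directly. As a quick sketch (using only the attribute names shown in the output above), we can inspect the edge list, feature matrix and labels of the single Cora graph:

graph = dataset[0]
print(graph.edge_index.shape)     # (2, 10556): source/destination node index pairs
print(graph.node_features.shape)  # (2708, 1433): bag-of-words features per paper
print(graph.node_labels[:5])      # class labels of the first five nodes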
Creating a simple neural network using node2vec#
[6]:
from mlx_graphs.algorithms import Node2Vec
Specify hyperparameters for node2vec#
The most important hyperparameters for node2vec are p and q (a small illustration of how they bias the walk follows the list):

1. p: the return parameter, which controls the likelihood of revisiting the previous node in the walk. When this is low, the algorithm is more likely to take a step back.
2. q: the in-out parameter, which controls how far the walk explores from the source node. When this is high, the walk stays close to its local neighbourhood (BFS-like); when it is low, the walk is more inclined to visit nodes further away (DFS-like).
3. embedding_dim: the dimensionality of each node embedding.
4. walk_length: the number of nodes in each random walk.
5. context_size: the window size actually used to build positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.
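To make the role of p and q concrete, here is a minimal sketch of the unnormalized transition weights node2vec uses (plain Python, not the mlx_cluster sampler; bias_weight is a hypothetical helper): when the walk sits at node v having arrived from node t, a candidate neighbour x is weighted 1/p if it means stepping back to t, 1 if x is also a neighbour of t, and 1/q otherwise.

def bias_weight(t, x, neighbors, p, q):
    """Unnormalized node2vec weight for stepping from v to x, having come from t."""
    if x == t:             # stepping back to the previous node -> return parameter p
        return 1.0 / p
    if x in neighbors[t]:  # x is also a neighbour of t, so the walk stays local
        return 1.0
    return 1.0 / q         # x moves further away from t -> in-out parameter q

# Toy graph: the walk just moved 0 -> 2; score the candidates for the next step
neighbors = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
for x in sorted(neighbors[2]):
    print(x, bias_weight(t=0, x=x, neighbors=neighbors, p=1.0, q=3.0))

With the values used below (p=1.0, q=3.0), stepping to node 3, which is not a neighbour of the previous node, is weighted three times lower than staying local.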
[7]:
embedding_dim = 128
walk_length = 20
context_size = 10
walks_per_node = 10
num_negative_samples = 1
p = 1.0
q = 3.0
[8]:
model = Node2Vec(
    edge_index=dataset[0].edge_index,
    num_nodes=dataset[0].num_nodes,
    embedding_dim=embedding_dim,
    walk_length=walk_length,
    context_size=context_size,
    walks_per_node=walks_per_node,
    num_negative_samples=num_negative_samples,
    p=p,
    q=q,
)
Let's now set up an optimizer and a dataloader, and train the model with a simple training loop.
[9]:
import mlx.optimizers as optim
[10]:
optimizer = optim.Adam(learning_rate=0.001)
[11]:
dataloader = model.dataloader(batch_size=64)
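Each element yielded by the dataloader is a pair of positive and negative walk batches. A quick, optional sketch to peek at one batch (the exact shapes depend on batch_size, walk_length, context_size, walks_per_node and num_negative_samples):

pos, neg = next(iter(dataloader))
print(pos.shape, pos.dtype)  # context windows drawn from the biased random walks (positive samples)
print(neg.shape, neg.dtype)  # windows built from randomly drawn nodes (negative samples)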
[12]:
import mlx.nn as nn
Creating a simple training loop to train an embedding model#
[13]:
for epoch in range(10):
    total_loss = 0
    # Re-create the dataloader each epoch to draw fresh random walks
    dataloader = model.dataloader(batch_size=32)
    for pos, neg in dataloader:
        # pos/neg are batches of positive and negative walk windows
        loss, grad = nn.value_and_grad(model, model.loss)(pos, neg)
        total_loss += loss.item()
        optimizer.update(model, grad)
    print(f"Epoch : {epoch} batch loss : {total_loss/32:.5f}")
Epoch : 0 batch loss : 2.88000
Epoch : 1 batch loss : 2.24894
Epoch : 2 batch loss : 2.16864
Epoch : 3 batch loss : 2.14403
Epoch : 4 batch loss : 2.13199
Epoch : 5 batch loss : 2.12255
Epoch : 6 batch loss : 2.11843
Epoch : 7 batch loss : 2.11297
Epoch : 8 batch loss : 2.10973
Epoch : 9 batch loss : 2.10481
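The loss decreases steadily, and the learned vectors can now be used for downstream tasks such as node classification or link prediction. As a final sketch, here is one way to compare two papers by cosine similarity. It assumes the trained table is exposed as an nn.Embedding attribute named embedding on the model (an assumption about the Node2Vec implementation; check the mlx_graphs source for the exact attribute name):

# Assumption: the learned (num_nodes, embedding_dim) table lives in model.embedding.weight
embeddings = model.embedding.weight

def cosine_similarity(a, b):
    return mx.sum(a * b) / (mx.linalg.norm(a) * mx.linalg.norm(b))

# Similarity between two arbitrary papers in the citation graph
print(cosine_similarity(embeddings[0], embeddings[1]).item())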