Node2vec#

[!WARNING] This tutorial is experimental and requires mlx_cluster to be installed, which currently requires mlx 0.18.

Goal: This tutorial will guide you through implementing node2vec to generate vector embeddings for nodes in a simple undirected graph.

Concepts: MLX, Node2vec

[1]:
from collections import defaultdict
import mlx.core as mx
from mlx_graphs.datasets import PlanetoidDataset

Dataset#

For this first tutorial, we will use the PlanetoidDataset collection, which comprises the citation networks Cora, PubMed and CiteSeer.

We will be using the Cora dataset, which consists of 2,708 nodes and 10,556 edges. The dataset can be easily accessed via the PlanetoidDataset class.

[2]:
dataset = PlanetoidDataset("Cora")
Loading cora data ... Done
[3]:
dataset
[3]:
cora(num_graphs=1)

We can access dataset properties directly from the dataset object.

[4]:
# Some useful properties
print("Dataset attributes")
print("-" * 20)
print(f"Number of graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of edge features: {dataset.num_edge_features}")
print(f"Number of graph features: {dataset.num_graph_features}")
print(f"Number of graph classes to predict: {dataset.num_graph_classes}\n")

# Statistics of the dataset
stats = defaultdict(list)
for g in dataset:
    stats["Mean node degree"].append(g.num_edges / g.num_nodes)
    stats["Mean num of nodes"].append(g.num_nodes)
    stats["Mean num of edges"].append(g.num_edges)

print("Dataset stats")
print("-" * 20)
for k, v in stats.items():
    mean = mx.mean(mx.array(v)).item()
    print(f"{k}: {mean:.2f}")
Dataset attributes
--------------------
Number of graphs: 1
Number of node features: 1433
Number of edge features: 0
Number of graph features: 0
Number of graph classes to predict: 0

Dataset stats
--------------------
Mean node degree: 3.90
Mean num of nodes: 2708.00
Mean num of edges: 10556.00
[5]:
dataset[0]
[5]:
GraphData(
        edge_index(shape=(2, 10556), int32)
        node_features(shape=(2708, 1433), float32)
        node_labels(shape=(2708,), int32)
        train_mask(shape=(2708,), bool)
        val_mask(shape=(2708,), bool)
        test_mask(shape=(2708,), bool))

Creating a simple neural network using node2vec#

[6]:
from mlx_graphs.algorithms import Node2Vec

Specify hyperparameters for node2vec#

The most important hyperparameters for node2vec are p and q:

1. p: specifies the likelihood of revisiting a node in the walk (return parameter). When p is low, the algorithm is more likely to take a step back.
2. q: specifies the likelihood of exploring nodes further away from the source (in-out parameter). When q is high, the walk is more likely to stay within the local neighbourhood of the source; when q is low, it wanders further away.
3. embedding_dim: dimensions of the embedding model.
4. walk_length: number of nodes to consider in a walk.
5. context_size: the actual context size considered for positive samples. This parameter increases the effective sampling rate by reusing samples across different source nodes.
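To build intuition for how p and q steer the walk, here is a small standalone sketch (illustration only, not part of the mlx_graphs API) of the unnormalized transition weights from the node2vec paper: after a walk arrives at the current node from node t, a candidate next node x is weighted 1/p if it steps back to t, 1 if it is also a neighbour of t, and 1/q otherwise.

[ ]:
# Illustration only: second-order walk weights from the node2vec paper,
# not how mlx_graphs computes its walks internally.
def walk_weight(t, x, neighbours, p, q):
    """Unnormalized weight for stepping to x, given the walk arrived from t."""
    if x == t:
        return 1.0 / p      # return parameter: low p encourages stepping back
    if x in neighbours[t]:
        return 1.0          # x is one hop away from the previous node t
    return 1.0 / q          # in-out parameter: high q keeps the walk local

# Toy undirected graph: edges 0-1, 0-2, 1-2, 1-3
neighbours = {0: {1, 2}, 1: {0, 2, 3}, 2: {0, 1}, 3: {1}}

# The walk just moved 0 -> 1; weight each possible next node
for x in sorted(neighbours[1]):
    print(x, walk_weight(0, x, neighbours, p=1.0, q=3.0))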

[7]:
embedding_dim = 128
walk_length = 20
context_size = 10
walks_per_node = 10
num_negative_samples = 1
p = 1.0
q = 3.0
[8]:
model = Node2Vec(
    edge_index=dataset[0].edge_index,
    num_nodes=dataset[0].num_nodes,
    embedding_dim=embedding_dim,
    walk_length=walk_length,
    context_size=context_size,
    walks_per_node=walks_per_node,
    num_negative_samples=num_negative_samples,
    p=p,
    q=q,
    )

Now let's train a simple model: set up an optimizer and a dataloader that yields batches of positive and negative random-walk samples.

[9]:
import mlx.optimizers as optim
[10]:
optimizer = optim.Adam(learning_rate=0.001)
[11]:
dataloader = model.dataloader(batch_size=64)
[12]:
import mlx.nn as nn

Creating a simple training loop to train an embedding model#

[13]:
for epoch in range(10):
    total_loss = 0
    # Dataloader yielding batches of positive and negative random-walk samples
    dataloader = model.dataloader(batch_size=32)
    for pos, neg in dataloader:
        # Compute the node2vec loss on this batch and its gradients
        loss, grad = nn.value_and_grad(model, model.loss)(pos, neg)
        total_loss += loss.item()
        optimizer.update(model, grad)
    print(f"Epoch : {epoch} batch loss : {total_loss/32:.5f}")
Epoch : 0 batch loss : 2.88000
Epoch : 1 batch loss : 2.24894
Epoch : 2 batch loss : 2.16864
Epoch : 3 batch loss : 2.14403
Epoch : 4 batch loss : 2.13199
Epoch : 5 batch loss : 2.12255
Epoch : 6 batch loss : 2.11843
Epoch : 7 batch loss : 2.11297
Epoch : 8 batch loss : 2.10973
Epoch : 9 batch loss : 2.10481
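
After training, the learned node vectors can be pulled out of the model for downstream tasks such as similarity search or node classification. A minimal sketch follows; it assumes the Node2Vec model stores its vectors in an nn.Embedding table exposed as model.embedding (check the mlx_graphs source if the attribute name differs).

[ ]:
# Assumption: the learned vectors live in an nn.Embedding named `embedding`
# with a (num_nodes, embedding_dim) weight matrix.
emb = model.embedding.weight

# Example: cosine similarity between the embeddings of nodes 0 and 1
a, b = emb[0], emb[1]
cos = mx.sum(a * b) / (mx.linalg.norm(a) * mx.linalg.norm(b))
print(f"Cosine similarity between node 0 and node 1: {cos.item():.4f}")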