Graph Classification

Goal: This first tutorial walks through an example of graph classification with mini-batching. We will generate node embeddings with a Graph Convolutional Network (GCN) model, aggregate these embeddings into graph embeddings via a readout operation, and speed up training by parallelizing operations with graph batching.

Concepts: Mini-batching, Readout, GNN training, MLX

[1]:
from collections import defaultdict

import mlx.core as mx
from mlx_graphs.datasets import TUDataset

Dataset

For this first tutorial, we will use the TUDataset collection, which comprises more than 120 datasets for graph classification and graph regression tasks.

The datasets proposed in this collection can be easily accessed via the TUDataset class.

We will use here the MUTAG dataset, where input graphs represent chemical compounds, with vertices symbolizing atoms identified by their atom type through one-hot encoding. Edges between vertices denote the bonds connecting the atoms. The dataset comprises 188 samples of chemical compounds, featuring 7 distinct node labels.

[2]:
dataset = TUDataset("MUTAG")
dataset
[2]:
MUTAG(num_graphs=188)

Dataset properties can be accessed directly from the dataset object, and we can also compute some statistics to better understand the dataset.

[3]:
# Some useful properties
print("Dataset attributes")
print("-" * 20)
print(f"Number of graphs: {len(dataset)}")
print(f"Number of node features: {dataset.num_node_features}")
print(f"Number of edge features: {dataset.num_edge_features}")
print(f"Number of graph features: {dataset.num_graph_features}")
print(f"Number of graph classes to predict: {dataset.num_graph_classes}\n")

# Statistics of the dataset
stats = defaultdict(list)
for g in dataset:
    stats["Mean node degree"].append(g.num_edges / g.num_nodes)
    stats["Mean num of nodes"].append(g.num_nodes)
    stats["Mean num of edges"].append(g.num_edges)

print("Dataset stats")
print("-" * 20)
for k, v in stats.items():
    mean = mx.mean(mx.array(v)).item()
    print(f"{k}: {mean:.2f}")
Dataset attributes
--------------------
Number of graphs: 188
Number of node features: 7
Number of edge features: 4
Number of graph features: 0
Number of graph classes to predict: 2

Dataset stats
--------------------
Mean node degree: 2.19
Mean num of nodes: 17.93
Mean num of edges: 39.59

A Dataset is essentially a wrapper around a list of GraphData objects. In mlx-graphs, a GraphData object holds the structure of a graph along with its features, similar to a DGLGraph in DGL or a Data object in PyG.

We can directly access these graphs from the dataset using indexing.

[4]:
dataset[0]
[4]:
GraphData(
        edge_index(shape=(2, 38), int32)
        node_features(shape=(17, 7), float32)
        edge_features(shape=(38, 4), float32)
        graph_labels(shape=(1,), int32))

The first graph of this dataset comprises 38 edges with 4 edge features and 17 nodes with 7 node features.
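
To see the one-hot encoding of atom types mentioned earlier, we can peek at a single node's feature vector (a quick illustrative check, not one of the numbered cells):

# Each row of node_features is a one-hot atom type, so exactly one entry is 1
print(dataset[0].node_features[0])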

When indexing a dataset with a sequence or a slice, we get back another Dataset object containing the selected graphs. Using this indexing strategy, the dataset can be divided into train and test sets.

[5]:
train_dataset = dataset[:150]
test_dataset = dataset[150:]

print(f"Training dataset: {train_dataset}")
print(f"Testing dataset: {test_dataset}")
Training dataset: MUTAG(num_graphs=150)
Testing dataset: MUTAG(num_graphs=38)

We use a Dataloader to split the datasets into iterable batches of graphs. Batching is highly recommended, as it parallelizes operations across graphs and thereby accelerates runtime.

Within each batch, all attributes are concatenated, so that multiple graphs are represented with a single array per attribute. Importantly, the graphs remain independent of one another (i.e., they are not connected together). To identify and extract the original graphs from a batch, each GraphDataBatch includes a batch_indices attribute, which maps every node in the batch back to its graph, making it easy to retrieve individual graphs from the batch structure (see the sketch after the next cell's output).

[6]:
from mlx_graphs.loaders import Dataloader

BATCH_SIZE = 64

train_loader = Dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = Dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

for batch in train_loader:
    print(f"\nGraph batch of size {len(batch)}")
    print(batch)
    print(batch.batch_indices)

Graph batch of size 64
GraphDataBatch(
        edge_index(shape=(2, 2590), int32)
        node_features(shape=(1168, 7), float32)
        edge_features(shape=(2590, 4), float32)
        graph_labels(shape=(64,), int32))
array([0, 0, 0, ..., 63, 63, 63], dtype=int64)

Graph batch of size 64
GraphDataBatch(
        edge_index(shape=(2, 2620), int32)
        node_features(shape=(1179, 7), float32)
        edge_features(shape=(2620, 4), float32)
        graph_labels(shape=(64,), int32))
array([0, 0, 0, ..., 63, 63, 63], dtype=int64)

Graph batch of size 22
GraphDataBatch(
        edge_index(shape=(2, 720), int32)
        node_features(shape=(337, 7), float32)
        edge_features(shape=(720, 4), float32)
        graph_labels(shape=(22,), int32))
array([0, 0, 0, ..., 21, 21, 21], dtype=int64)
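
As a sketch of how batch_indices can be used in practice (illustrative code, not one of the numbered cells), we can gather the node features belonging to a single graph of a batch. The helper below, nodes_of_graph, is a hypothetical name; it builds a plain integer index from batch_indices so it does not rely on boolean-mask indexing:

def nodes_of_graph(batch, i):
    # Collect the positions of all nodes mapped to graph `i` by batch_indices
    node_ids = [n for n, g in enumerate(batch.batch_indices.tolist()) if g == i]
    return batch.node_features[mx.array(node_ids)]

# `batch` is the last batch from the loop above (22 graphs)
print(nodes_of_graph(batch, 0).shape)  # (num_nodes of graph 0, 7)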

GCN model

Let’s define a basic 3-layer Graph Convolutional Network (GCN) using the GCNConv layer. It is as simple as creating a new class inheriting from mlx.nn.Module and implementing the __call__ method, responsible for the forward pass.

We employ global_mean_pool, also known as a readout operation, to compute the mean of all node embeddings within each graph, resulting in one graph embedding per graph that we pass as input to a final linear layer for graph classification. As we are working with batches of graphs, we need to provide the batch_indices of the batch to global_mean_pool so that the pooling is computed individually for each graph. The output of the pooling operation thus has shape (num_graphs_in_batch, hidden_dim), here (64, 64).
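
To make the readout concrete, here is a toy sketch (not one of the numbered cells) of global_mean_pool with the same (node_features, batch_indices) call signature used in the model below: two graphs with 3 and 2 nodes are each reduced to a single mean embedding.

from mlx_graphs.nn import global_mean_pool

# 5 node embeddings of size 2, belonging to two graphs (3 nodes + 2 nodes)
h = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 10.0], [20.0, 20.0]])
batch_indices = mx.array([0, 0, 0, 1, 1])

# One mean embedding per graph -> shape (2, 2): [[3, 4], [15, 15]]
print(global_mean_pool(h, batch_indices))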

[7]:
import mlx.nn as nn
from mlx_graphs.nn import GCNConv, global_mean_pool, Linear


class GCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
        super(GCN, self).__init__()

        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.linear = Linear(hidden_dim, out_dim)

        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, edge_index, node_features, batch_indices):
        h = nn.relu(self.conv1(edge_index, node_features))
        h = nn.relu(self.conv2(edge_index, h))
        h = self.conv3(edge_index, h)

        h = global_mean_pool(h, batch_indices)

        h = self.dropout(h)
        h = self.linear(h)

        return h
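
As a quick sanity check (an illustrative snippet rather than one of the numbered cells), a single forward pass on the last batch from above should yield one logit vector per graph:

# Throwaway model instance just to verify output shapes
gcn = GCN(in_dim=dataset.num_node_features, hidden_dim=64, out_dim=dataset.num_graph_classes)
out = gcn(batch.edge_index, batch.node_features, batch.batch_indices)
print(out.shape)  # (num_graphs_in_batch, num_graph_classes), e.g. (22, 2)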

We will train our GCN model in a supervised fashion using a cross-entropy loss. Here’s how we define the loss function and the forward function in MLX. In MLX, it is recommended to write a dedicated function for the forward pass, as this function will later be passed to nn.value_and_grad in order to compute the gradients of the model w.r.t. the loss returned by the forward function.

[8]:
def loss_fn(y_hat, y, parameters=None):
    return mx.mean(nn.losses.cross_entropy(y_hat, y))

def forward_fn(model, graph, labels):
    y_hat = model(graph.edge_index, graph.node_features, graph.batch_indices)
    loss = loss_fn(y_hat, labels, model.parameters())
    return loss, y_hat

By default, MLX computations are performed on the Mac’s integrated GPU, leveraging its multiple cores for efficient operations. This is the preferred method for mlx-graphs to optimize parallel processing of GNN tasks. However, you can effortlessly switch between computing on the CPU and the GPU using the following method:

[9]:
device = mx.gpu # or mx.cpu
mx.set_default_device(device)

Training the GCN model works similarly to other frameworks such as JAX or PyTorch.

[10]:
def train(train_loader):
    loss_sum = 0.0
    for graph in train_loader:
        # Compute the loss and its gradients w.r.t. the model parameters
        (loss, y_hat), grads = loss_and_grad_fn(
            model=model,
            graph=graph,
            labels=graph.graph_labels,
        )
        # Update the parameters and force evaluation of the lazy computation
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        loss_sum += loss.item()
    return loss_sum / len(train_loader.dataset)

def test(loader):
    acc = 0.0
    for graph in loader:
        y_hat = model(graph.edge_index, graph.node_features, graph.batch_indices)
        y_hat = y_hat.argmax(axis=1)
        acc += (y_hat == graph.graph_labels).sum().item()

    return acc / len(loader.dataset)

def epoch():
    loss = train(train_loader)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    return loss, train_acc, test_acc
[11]:
import mlx.optimizers as optim
mx.random.seed(42)

model = GCN(
    in_dim=dataset.num_node_features,
    hidden_dim=64,
    out_dim=dataset.num_graph_classes,
)
mx.eval(model.parameters())

optimizer = optim.Adam(learning_rate=0.01)
loss_and_grad_fn = nn.value_and_grad(model, forward_fn)

epochs = 30
best_test_acc = 0.0
for e in range(epochs):
    loss, train_acc, test_acc = epoch()
    best_test_acc = max(best_test_acc, test_acc)

    print(
        " | ".join(
            [
                f"Epoch: {e:3d}",
                f"Train loss: {loss:.3f}",
                f"Train acc: {train_acc:.3f}",
                f"Test acc: {test_acc:.3f}",
            ]
        )
    )
print(f"\n==> Best test accuracy: {best_test_acc:.3f}")
Epoch:   0 | Train loss: 0.022 | Train acc: 0.387 | Test acc: 0.447
Epoch:   1 | Train loss: 0.014 | Train acc: 0.667 | Test acc: 0.684
Epoch:   2 | Train loss: 0.013 | Train acc: 0.640 | Test acc: 0.658
Epoch:   3 | Train loss: 0.013 | Train acc: 0.660 | Test acc: 0.658
Epoch:   4 | Train loss: 0.013 | Train acc: 0.633 | Test acc: 0.684
Epoch:   5 | Train loss: 0.013 | Train acc: 0.507 | Test acc: 0.526
Epoch:   6 | Train loss: 0.013 | Train acc: 0.647 | Test acc: 0.684
Epoch:   7 | Train loss: 0.013 | Train acc: 0.580 | Test acc: 0.658
Epoch:   8 | Train loss: 0.013 | Train acc: 0.633 | Test acc: 0.658
Epoch:   9 | Train loss: 0.012 | Train acc: 0.680 | Test acc: 0.684
Epoch:  10 | Train loss: 0.013 | Train acc: 0.753 | Test acc: 0.605
Epoch:  11 | Train loss: 0.013 | Train acc: 0.700 | Test acc: 0.737
Epoch:  12 | Train loss: 0.010 | Train acc: 0.700 | Test acc: 0.711
Epoch:  13 | Train loss: 0.012 | Train acc: 0.647 | Test acc: 0.711
Epoch:  14 | Train loss: 0.013 | Train acc: 0.713 | Test acc: 0.632
Epoch:  15 | Train loss: 0.012 | Train acc: 0.687 | Test acc: 0.658
Epoch:  16 | Train loss: 0.011 | Train acc: 0.733 | Test acc: 0.605
Epoch:  17 | Train loss: 0.011 | Train acc: 0.687 | Test acc: 0.711
Epoch:  18 | Train loss: 0.013 | Train acc: 0.747 | Test acc: 0.579
Epoch:  19 | Train loss: 0.012 | Train acc: 0.640 | Test acc: 0.684
Epoch:  20 | Train loss: 0.013 | Train acc: 0.747 | Test acc: 0.658
Epoch:  21 | Train loss: 0.010 | Train acc: 0.733 | Test acc: 0.658
Epoch:  22 | Train loss: 0.011 | Train acc: 0.713 | Test acc: 0.658
Epoch:  23 | Train loss: 0.011 | Train acc: 0.707 | Test acc: 0.684
Epoch:  24 | Train loss: 0.010 | Train acc: 0.720 | Test acc: 0.711
Epoch:  25 | Train loss: 0.011 | Train acc: 0.727 | Test acc: 0.684
Epoch:  26 | Train loss: 0.011 | Train acc: 0.727 | Test acc: 0.579
Epoch:  27 | Train loss: 0.012 | Train acc: 0.713 | Test acc: 0.632
Epoch:  28 | Train loss: 0.011 | Train acc: 0.720 | Test acc: 0.605
Epoch:  29 | Train loss: 0.011 | Train acc: 0.707 | Test acc: 0.711

==> Best test accuracy: 0.737
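
As a final usage sketch (illustrative, not part of the original tutorial), the trained model can classify a single test graph. Since a standalone GraphData has no batch_indices, we map all of its nodes to graph index 0:

# Disable dropout for inference
model.eval()

graph = test_dataset[0]

# A standalone GraphData has no batch_indices: map all nodes to graph 0
single_batch_indices = mx.zeros([graph.num_nodes], dtype=mx.int64)

logits = model(graph.edge_index, graph.node_features, single_batch_indices)
print(f"Predicted class: {logits.argmax(axis=1).item()} | True label: {graph.graph_labels.item()}")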