Quantum-chemical property prediction with GNNs#

In this mlx-graphs tutorial we explore how to use graph neural networks to predict quantum-chemical properties of small molecules. You may need to install mlx, mlx-graphs, tqdm and matplotlib to run it.

Loading the dataset#

We will be using the QM9 dataset as provided via TUDataset, which consists of about 130,000 small molecules with up to 9 heavy atoms (C, O, N, F) plus hydrogens. Each molecule includes the spatial information of its single low-energy conformation (i.e., the 3D coordinates of the atoms, stored in the last 3 columns of the node_features of the graph) in addition to 13 other node features. The dataset provides 19 regression tasks, corresponding to different quantum-chemical properties of the molecules.

Let’s start by loading the dataset and looking at some of its properties.

[1]:
from mlx_graphs.datasets import TUDataset
qm9 = TUDataset(name="QM9")
[2]:
print(f"Number of graphs: {len(qm9)}")
print(f"Number of node features: {qm9.num_node_features}")
print(f"Number of edge features: {qm9.num_edge_features}")
print(f"Number of regression tasks: {qm9[0].graph_labels.shape[1]}")
Number of graphs: 129433
Number of node features: 16
Number of edge features: 4
Number of regression tasks: 19
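
To get a feel for a single entry, we can index the dataset directly. The snippet below is a quick sketch: it prints the first graph and slices out the last 3 node-feature columns, which hold the atom coordinates.

[ ]:
# A quick look at a single graph (sketch): the last 3 node-feature columns hold the
# 3D atom coordinates, the remaining 13 columns are the other atom features.
graph = qm9[0]
print(graph)
print("Node features:", graph.node_features.shape)
print("Atom coordinates:", graph.node_features[:, -3:].shape)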

We’ll split the dataset into training and test sets and create a dataloader for each.

[3]:
from mlx_graphs.loaders import Dataloader

# training and test splits
num_training_samples = 110_000

training_dataset = qm9[:num_training_samples]
test_dataset = qm9[num_training_samples:]

# dataloaders
batch_size = 128

train_loader = Dataloader(training_dataset, batch_size=batch_size, shuffle=True)
test_loader = Dataloader(test_dataset, batch_size=batch_size, shuffle=False)
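
If you want to check what the loader produces, the sketch below peeks at a single mini-batch; num_graphs, node_features and batch_indices are the same batch attributes used later in this tutorial.

[ ]:
# Peek at one mini-batch produced by the loader (sketch).
batch = next(iter(train_loader))
print("Graphs in batch:", batch.num_graphs)
print("Node features in batch:", batch.node_features.shape)
print("batch_indices (node -> graph within the batch):", batch.batch_indices.shape)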

Graph Neural Network module#

To dive a bit deeper into the GNN logic and allow for more flexibility, we will implement node and edge processing blocks from scratch (as in https://arxiv.org/pdf/1806.01261.pdf) and then combine them through the GraphNetworkBlock provided by mlx-graphs.

Let’s start with the edge features update model. It will concatenate the edge_features with the node_features of the corresponding source and target nodes, pass them through a simple MLP (just a linear layer followed by a ReLU) and output the updated edge_features.

[4]:
import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.nn import Linear
from mlx_graphs.utils import scatter
[5]:
class EdgeModel(nn.Module):
    def __init__(
        self,
        edge_features_dim: int,
        node_features_dim: int,
        output_dim: int,
    ):
        super().__init__()
        self.linear = Linear(
            input_dims=2 * node_features_dim + edge_features_dim, # source + target features + edge features
            output_dims=output_dim,
        )

    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_features: mx.array,
        graph_features = None
    ):
        source_nodes = edge_index[0]
        destination_nodes = edge_index[1]
        model_input = mx.concatenate(
            [
                node_features[destination_nodes],
                node_features[source_nodes],
                edge_features,
            ],
            1,
        )
        new_edge_features = self.linear(model_input)
        new_edge_features = nn.relu(new_edge_features)
        return new_edge_features

The node features update model will start by aggregating the edge_features of the edges incoming to each node, computing their mean. These aggregated edge features will then be concatenated with the corresponding node_features and passed through an MLP to compute the updated node features.

[6]:
class NodeModel(nn.Module):
    def __init__(
        self,
        node_features_dim: int,
        edge_features_dim: int,
        output_dim: int,
    ):
        super().__init__()
        self.linear = Linear(
            input_dims=node_features_dim + edge_features_dim,
            output_dims=output_dim,
        )

    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_features: mx.array,
        graph_features = None
    ):
        target_nodes = edge_index[1]
        aggregated_edges = scatter(edge_features, target_nodes, out_size=node_features.shape[0], aggr="mean")
        model_input = mx.concatenate([node_features, aggregated_edges], 1)
        new_node_features = self.linear(model_input)
        new_node_features = nn.relu(new_node_features)
        return new_node_features
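
As an optional sanity check, we can run the two blocks on a small toy graph and verify the output shapes; all dimensions and values below are made up purely for illustration.

[ ]:
# Optional sanity check on a toy graph (all dimensions and values are made up).
num_nodes, num_edges = 4, 5
toy_edge_index = mx.array([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])  # (2, num_edges)
toy_node_features = mx.random.normal((num_nodes, 16))
toy_edge_features = mx.random.normal((num_edges, 4))

edge_model = EdgeModel(edge_features_dim=4, node_features_dim=16, output_dim=64)
node_model = NodeModel(node_features_dim=16, edge_features_dim=64, output_dim=64)

new_edge_features = edge_model(toy_edge_index, toy_node_features, toy_edge_features)
new_node_features = node_model(toy_edge_index, toy_node_features, new_edge_features)
print(new_edge_features.shape)  # (num_edges, 64)
print(new_node_features.shape)  # (num_nodes, 64)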

Now, the edge and node models we defined can be combined to create a basic graph neural network that sequentially updates the edge and node features of the input graph.

To do that we can use the GraphNetworkBlock. In the model below we create two graph network blocks that we apply sequentially; we then pool the node features and pass them through a linear transformation to compute the predicted output. The model has an output dimension of 1, as we’ll use it to learn a single regression task.

[7]:
from mlx_graphs.nn import global_mean_pool, GraphNetworkBlock


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.gnn1 = GraphNetworkBlock(
            node_model=NodeModel(qm9.num_node_features, 64, 64),
            edge_model=EdgeModel(qm9.num_edge_features, qm9.num_node_features, 64),
        )
        self.gnn2 = GraphNetworkBlock(
            node_model=NodeModel(64, 64, 64),
            edge_model=EdgeModel(64, 64, 64),
        )
        self.output_linear = Linear(input_dims=64, output_dims=1)

    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_features: mx.array,
        batch_indices: mx.array,
    ):

        node_features, edge_features, _ = self.gnn1(
            edge_index=edge_index,
            node_features=node_features,
            edge_features=edge_features,
        )
        node_features, _, _ = self.gnn2(
            edge_index=edge_index,
            node_features=node_features,
            edge_features=edge_features,
        )
        out = global_mean_pool(node_features, batch_indices)
        return self.output_linear(out)
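
If the pooling step is unfamiliar, the toy example below illustrates what global_mean_pool does: node features are averaged per graph according to batch_indices, yielding one feature vector per graph (the values are made up for illustration).

[ ]:
# Toy illustration of global_mean_pool: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1,
# so pooling returns one averaged feature vector per graph.
toy_node_features = mx.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [2.0, 2.0], [4.0, 4.0]])
toy_batch_indices = mx.array([0, 0, 1, 1, 1])
print(global_mean_pool(toy_node_features, toy_batch_indices))  # [[2, 3], [2, 2]]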

Model training#

Now that we have our model, we can define our loss function (MSE) and the forward function we’ll use to train the model, along with an optimizer (AdamW in this example).

[8]:
import mlx.optimizers as optim

model = Model()

mx.eval(model.parameters())

def loss_fn(y_hat, y, parameters=None):
    return nn.losses.mse_loss(y_hat, y, "sum")

def forward_fn(model, graph, labels):
    y_hat = model(graph.edge_index, graph.node_features, graph.edge_features, graph.batch_indices)
    loss = loss_fn(y_hat, labels, model.parameters())
    return loss, y_hat

loss_and_grad_fn = nn.value_and_grad(model, forward_fn)
optimizer = optim.AdamW(learning_rate=1e-3)

We’re now ready to train our GNN model. Let’s select the “Lowest unoccupied molecular orbital energy” as the target regression task.

[9]:
from tqdm import tqdm

mx.set_default_device(mx.gpu)
regression_task = 4

len_training_loader = int(len(training_dataset)/batch_size + 1)
len_test_loader = int(len(test_dataset)/batch_size + 1)
epochs = 20

epoch_training_losses = []
epoch_test_losses = []
for epoch in range(epochs):
    # training loop
    with tqdm(train_loader, total=len_training_loader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch + 1}")
        training_loss = 0.0
        for idx, data in enumerate(tepoch):
            (loss, y_hat), grads = loss_and_grad_fn(
                model=model,
                graph=data,
                labels=mx.expand_dims(data.graph_labels[:, regression_task], 1),
            )
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)
            training_loss += loss.item()/data.num_graphs
            tepoch.set_postfix(training_loss=training_loss / (idx+1), test_loss=0)
        epoch_training_losses.append(training_loss/(idx+1))

    # testing loop
    with tqdm(test_loader, total=len_test_loader, unit="batch", leave=False) as test_epoch:
        test_epoch.set_description(f"Testing")
        test_loss = 0.0
        for idx, data in enumerate(test_epoch):
            y_hat = model(data.edge_index, data.node_features, data.edge_features, data.batch_indices)
            loss = loss_fn(y_hat, mx.expand_dims(data.graph_labels[:, regression_task], 1))
            test_loss += loss.item()/data.num_graphs
            test_epoch.set_postfix(test_loss=test_loss / (idx+1))
        epoch_test_losses.append(test_loss/(idx+1))
Epoch 1: 100%|████████████████████████████████████| 860/860 [00:06<00:00, 137.66batch/s, test_loss=0, training_loss=1.39]
Epoch 2: 100%|███████████████████████████████████| 860/860 [00:06<00:00, 141.06batch/s, test_loss=0, training_loss=0.634]
Epoch 3: 100%|███████████████████████████████████| 860/860 [00:05<00:00, 147.70batch/s, test_loss=0, training_loss=0.406]
Epoch 4: 100%|███████████████████████████████████| 860/860 [00:05<00:00, 151.08batch/s, test_loss=0, training_loss=0.357]
Epoch 5: 100%|███████████████████████████████████| 860/860 [00:05<00:00, 151.76batch/s, test_loss=0, training_loss=0.335]
Epoch 6: 100%|███████████████████████████████████| 860/860 [00:05<00:00, 151.85batch/s, test_loss=0, training_loss=0.316]
Epoch 7: 100%|███████████████████████████████████| 860/860 [00:05<00:00, 154.13batch/s, test_loss=0, training_loss=0.303]
Epoch 8: 100%|███████████████████████████████████| 860/860 [00:05<00:00, 154.84batch/s, test_loss=0, training_loss=0.285]
Epoch 9: 100%|███████████████████████████████████| 860/860 [00:06<00:00, 131.16batch/s, test_loss=0, training_loss=0.264]
Epoch 10: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 135.10batch/s, test_loss=0, training_loss=0.249]
Epoch 11: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 129.22batch/s, test_loss=0, training_loss=0.234]
Epoch 12: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 130.47batch/s, test_loss=0, training_loss=0.218]
Epoch 13: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 131.18batch/s, test_loss=0, training_loss=0.205]
Epoch 14: 100%|██████████████████████████████████| 860/860 [00:05<00:00, 150.44batch/s, test_loss=0, training_loss=0.194]
Epoch 15: 100%|██████████████████████████████████| 860/860 [00:05<00:00, 147.08batch/s, test_loss=0, training_loss=0.182]
Epoch 16: 100%|██████████████████████████████████| 860/860 [00:05<00:00, 150.98batch/s, test_loss=0, training_loss=0.171]
Epoch 17: 100%|██████████████████████████████████| 860/860 [00:05<00:00, 149.20batch/s, test_loss=0, training_loss=0.167]
Epoch 18: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 129.81batch/s, test_loss=0, training_loss=0.162]
Epoch 19: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 127.19batch/s, test_loss=0, training_loss=0.156]
Epoch 20: 100%|██████████████████████████████████| 860/860 [00:06<00:00, 128.22batch/s, test_loss=0, training_loss=0.153]

Now that training has concluded, we can visualize the evolution of the training and test losses.

[10]:
import matplotlib.pyplot as plt
[11]:
plt.figure()
plt.plot(range(epochs), epoch_training_losses, label="Training")
plt.plot(range(epochs), epoch_test_losses, label="Test")
plt.title("Loss")
plt.legend()
plt.show()
[Figure: training and test loss over the 20 epochs]

Notes and next steps#

While this is an illustrative tutorial, we hope it gives some insight into mlx-graphs. Next steps could involve implementing more complex and specialized models (e.g., exploiting invariances and equivariances in the molecular geometries), or normalizing features and target properties, which can also substantially improve performance.
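
For instance, a minimal sketch of target standardization might look like the following; it assumes the training_dataset and regression_task defined above, and the helper names are illustrative.

[ ]:
# Minimal sketch of target standardization (helper names are illustrative).
# Statistics are computed on the training set only, to avoid leaking test information.
train_labels = mx.concatenate(
    [g.graph_labels[:, regression_task] for g in training_dataset]
)
target_mean = mx.mean(train_labels)
target_std = mx.std(train_labels)

def standardize(y):
    return (y - target_mean) / target_std

def destandardize(y_hat):
    return y_hat * target_std + target_mean

# During training, fit on standardize(labels); at evaluation time,
# map predictions back to the original scale with destandardize(y_hat).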