Source code for mlx_graphs.nn.conv.gcn_conv
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.nn.message_passing import MessagePassing
from mlx_graphs.utils import add_self_loops, degree, invert_sqrt_degree
[docs]
class GCNConv(MessagePassing):
"""Applies a GCN convolution over input node features.
Args:
node_features_dim: size of input node features
out_features_dim: size of output node embeddings
bias: whether to use bias in the node projection
add_self_loops: whether to add a self-loop for each node
"""
def __init__(
self,
node_features_dim: int,
out_features_dim: int,
bias: bool = True,
add_self_loops: bool = False,
**kwargs,
):
kwargs.setdefault("aggr", "add")
super(GCNConv, self).__init__(**kwargs)
self.linear = nn.Linear(node_features_dim, out_features_dim, bias)
self._add_self_loops = add_self_loops
[docs]
def __call__(
self,
edge_index: mx.array,
node_features: mx.array,
edge_weights: Optional[mx.array] = None,
normalize: bool = True,
**kwargs: Any,
) -> mx.array:
assert edge_index.shape[0] == 2, "edge_index must have shape (2, num_edges)"
assert (
edge_index[1].size > 0
), "'col' component of edge_index should not be empty"
node_features = self.linear(node_features)
if self._add_self_loops:
edge_index = add_self_loops(edge_index)
row, col = edge_index
# Compute node degree normalization for the mean aggregation.
norm: Optional[mx.array] = None
if normalize:
deg = degree(col, node_features.shape[0], edge_weights=edge_weights)
# NOTE : need boolean indexing in order to zero out inf values
deg_inv_sqrt = invert_sqrt_degree(deg)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Compute messages and aggregate them with sum and norm.
node_features = self.propagate(
edge_index=edge_index,
node_features=node_features,
message_kwargs={"edge_weights": norm},
)
return node_features