Source code for mlx_graphs.nn.conv.gat_conv

from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from mlx_graphs.nn.linear import Linear
from mlx_graphs.nn.message_passing import MessagePassing
from mlx_graphs.utils import get_src_dst_features, scatter


class GATConv(MessagePassing):
    """Graph Attention Network convolution layer.

    Args:
        node_features_dim: Size of input node features
        out_features_dim: Size of output node embeddings
        heads: Number of attention heads
        concat: Whether to concatenate the attention heads or average them
        bias: Whether to use bias in the node projection
        negative_slope: Slope for the leaky ReLU
        dropout: Probability p for dropout
        edge_features_dim: Size of edge features

    Example:

    .. code-block:: python

        conv = GATConv(16, 32, heads=2, concat=True)
        edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
        node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))

        h = conv(edge_index, node_features)

        >>> h.shape
        [5, 64]
    """

    def __init__(
        self,
        node_features_dim: int,
        out_features_dim: int,
        heads: int = 1,
        concat: bool = True,
        bias: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        edge_features_dim: Optional[int] = None,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super(GATConv, self).__init__(**kwargs)

        self.out_features_dim = out_features_dim
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope

        self.lin_proj = Linear(node_features_dim, heads * out_features_dim, bias=False)

        glorot_init = nn.init.glorot_uniform()
        self.att_src = glorot_init(mx.zeros((1, heads, out_features_dim)))
        self.att_dst = glorot_init(mx.zeros((1, heads, out_features_dim)))

        if bias:
            bias_shape = (heads * out_features_dim) if concat else (out_features_dim)
            self.bias = mx.zeros(bias_shape)

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)

        if edge_features_dim is not None:
            self.edge_lin_proj = Linear(
                edge_features_dim, heads * out_features_dim, bias=False
            )
            self.edge_att = glorot_init(mx.zeros((1, heads, out_features_dim)))
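    # A sketch of the attention computed by this layer (standard GAT
    # formulation; the notation here is ours, not from the mlx_graphs docs).
    # For an edge (i -> j) and each head:
    #
    #     e_ij     = LeakyReLU(att_src . (W x_i) + att_dst . (W x_j))
    #     alpha_ij = softmax of e_ij over all edges incoming to j
    #
    # where W is `lin_proj`. The embedding of node j is then the
    # alpha-weighted sum of its neighbors' projected features, per head.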
    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_features: Optional[mx.array] = None,
    ) -> mx.array:
        """Computes the forward pass of GATConv.

        Args:
            edge_index: input edge index of shape `[2, num_edges]`
            node_features: input node features
            edge_features: edge features

        Returns:
            mx.array: computed node embeddings
        """
        H, C = self.heads, self.out_features_dim

        src_feats = dst_feats = self.lin_proj(node_features).reshape(-1, H, C)

        alpha_src = (src_feats * self.att_src).sum(-1)
        alpha_dst = (dst_feats * self.att_dst).sum(-1)
        alpha_src, alpha_dst = get_src_dst_features(edge_index, (alpha_src, alpha_dst))

        dst_idx = edge_index[1]
        node_features = self.propagate(
            node_features=(src_feats, dst_feats),
            edge_index=edge_index,
            message_kwargs={
                "alpha_src": alpha_src,
                "alpha_dst": alpha_dst,
                "index": dst_idx,
                "edge_features": edge_features,
            },
        )

        if self.concat:
            node_features = node_features.reshape(
                -1, self.heads * self.out_features_dim
            )
        else:
            node_features = mx.mean(node_features, axis=1)

        if "bias" in self:
            node_features = node_features + self.bias

        return node_features
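    # Output shape note: with concat=True the heads are concatenated, giving
    # (num_nodes, heads * out_features_dim); with concat=False they are
    # averaged, giving (num_nodes, out_features_dim). E.g. the docstring
    # example (heads=2, out_features_dim=32, concat=True) yields (5, 64),
    # whereas concat=False would yield (5, 32).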
    def message(
        self,
        src_features: mx.array,
        dst_features: mx.array,
        alpha_src: mx.array,
        alpha_dst: mx.array,
        index: mx.array,
        edge_features: Optional[mx.array] = None,
    ) -> mx.array:
        """Computes a message for each edge in the graph following
        GAT's propagation rule.

        Args:
            src_features: Features of the source nodes.
            dst_features: Features of the destination nodes (not used in this
                function but included for compatibility).
            alpha_src: Precomputed attention values for the source nodes.
            alpha_dst: Precomputed attention values for the destination nodes.
            index: 1D array with indices of either src or dst nodes to compute
                the softmax.
            edge_features: Features of the edges in the graph.

        Returns:
            mx.array: The computed messages for each edge in the graph.
        """
        alpha = alpha_src + alpha_dst

        if edge_features is not None:
            alpha_edge = self._compute_alpha_edge_features(edge_features)
            alpha = alpha + alpha_edge

        alpha = nn.leaky_relu(alpha, self.negative_slope)
        alpha = scatter(alpha, index, self.num_nodes, aggr="softmax")

        if "dropout" in self:
            alpha = self.dropout(alpha)

        return mx.expand_dims(alpha, -1) * src_features
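    # Illustration of the softmax scatter above (made-up values): with
    # index = [0, 0, 1] and logits alpha = [1.0, 1.0, 2.0] for one head,
    # edges 0 and 1 both point at node 0 and are normalized together
    # (weight 0.5 each), while edge 2 is the only edge into node 1 and
    # receives weight 1.0.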
    def _compute_alpha_edge_features(self, edge_features: mx.array):
        assert all(
            layer in self for layer in ["edge_lin_proj", "edge_att"]
        ), "To use edge features, the GATConv layer must be given the `edge_features_dim` argument."

        if edge_features.ndim == 1:
            edge_features = edge_features.reshape(-1, 1)

        edge_features = self.edge_lin_proj(edge_features)
        edge_features = edge_features.reshape(-1, self.heads, self.out_features_dim)
        alpha_edge = (edge_features * self.edge_att).sum(axis=-1)

        return alpha_edge
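# Usage sketch (not part of the library source): a minimal example of running
# GATConv with edge features. All dimensions and random inputs below are
# illustrative assumptions, not values prescribed by mlx_graphs.
if __name__ == "__main__":
    conv = GATConv(
        node_features_dim=16,
        out_features_dim=32,
        heads=4,
        concat=False,
        edge_features_dim=8,
    )

    # 4 directed edges over 4 nodes: 0->1, 1->2, 2->3, 3->0
    edge_index = mx.array([[0, 1, 2, 3], [1, 2, 3, 0]])
    node_features = mx.random.uniform(shape=(4, 16))
    edge_features = mx.random.uniform(shape=(4, 8))

    h = conv(edge_index, node_features, edge_features)
    print(h.shape)  # expected (4, 32): concat=False averages the 4 heads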