Source code for mlx_graphs.nn.conv.gatv2_conv

from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from mlx_graphs.nn.linear import Linear
from mlx_graphs.nn.message_passing import MessagePassing
from mlx_graphs.utils import get_src_dst_features, scatter


class GATv2Conv(MessagePassing):
    """Graph Attention Network convolution layer with dynamic attention.

    Modification of GATConv based on the `"How Attentive are Graph Attention
    Networks?" <https://arxiv.org/abs/2105.14491>`_ paper.

    Args:
        node_features_dim: Size of input node features
        out_features_dim: Size of output node embeddings
        heads: Number of attention heads
        concat: Whether to concatenate the attention heads or use mean reduction
        bias: Whether to use bias in the node projection
        negative_slope: Slope for the leaky ReLU
        dropout: Probability p for dropout
        edge_features_dim: Size of edge features

    Example:

    .. code-block:: python

        conv = GATv2Conv(16, 32, heads=2, concat=True)
        edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
        node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))
        h = conv(edge_index, node_features)

        >>> h.shape
        [5, 64]
    """

    def __init__(
        self,
        node_features_dim: int,
        out_features_dim: int,
        heads: int = 1,
        concat: bool = True,
        bias: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        edge_features_dim: Optional[int] = None,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super(GATv2Conv, self).__init__(**kwargs)

        self.out_features_dim = out_features_dim
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope

        # Weights
        self.W_src = Linear(node_features_dim, heads * out_features_dim, bias=False)
        self.W_dst = Linear(node_features_dim, heads * out_features_dim, bias=False)

        # Attention is applied in message() during the message passing stage.
        glorot_init = nn.init.glorot_uniform()
        self.att = glorot_init(mx.zeros((1, heads, out_features_dim)))

        if bias:
            bias_shape = (heads * out_features_dim) if concat else (out_features_dim)
            self.bias = mx.zeros(bias_shape)

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)

        if edge_features_dim is not None:
            self.edge_lin_proj = Linear(
                edge_features_dim, heads * out_features_dim, bias=False
            )
            self.edge_att = glorot_init(mx.zeros((1, heads, out_features_dim)))
    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_features: Optional[mx.array] = None,
    ) -> mx.array:
        """Computes the forward pass of GATv2Conv.

        Args:
            edge_index: input edge index of shape `[2, num_edges]`
            node_features: input node features
            edge_features: edge features

        Returns:
            mx.array: computed node embeddings
        """
        H, C = self.heads, self.out_features_dim

        src_feats = self.W_src(node_features).reshape(-1, H, C)
        dst_feats = self.W_dst(node_features).reshape(-1, H, C)

        src, dst = get_src_dst_features(edge_index, (src_feats, dst_feats))
        dst_idx = edge_index[1]

        node_features = self.propagate(
            node_features=(src_feats, dst_feats),
            edge_index=edge_index,
            message_kwargs={
                "src": src,
                "dst": dst,
                "index": dst_idx,
                "edge_features": edge_features,
            },
        )

        if self.concat:
            node_features = node_features.reshape(
                -1, self.heads * self.out_features_dim
            )
        else:
            node_features = mx.mean(node_features, axis=1)

        if "bias" in self:
            node_features = node_features + self.bias

        return node_features
    def message(
        self,
        src_features: mx.array,
        dst_features: mx.array,
        src: mx.array,
        dst: mx.array,
        index: mx.array,
        edge_features: Optional[mx.array] = None,
    ) -> mx.array:
        # Collect alpha before applying non-linearity
        alpha = src + dst
        if edge_features is not None:
            alpha_edge = self._compute_edge_features(edge_features)
            alpha = alpha + alpha_edge

        alpha = nn.leaky_relu(alpha, self.negative_slope)

        # Apply attention after non-linearity to get dynamic attention
        alpha = (alpha * self.att).sum(-1)
        alpha = scatter(alpha, index, self.num_nodes, aggr="softmax")

        if "dropout" in self:
            alpha = self.dropout(alpha)

        return mx.expand_dims(alpha, -1) * src_features
    def _compute_edge_features(self, edge_features: mx.array):
        assert all(layer in self for layer in ["edge_lin_proj", "edge_att"]), (
            "To use edge features, the GATv2Conv layer should be provided the "
            "`edge_features_dim` argument."
        )

        if edge_features.ndim == 1:
            edge_features = edge_features.reshape(-1, 1)

        edge_features = self.edge_lin_proj(edge_features)
        edge_features = edge_features.reshape(-1, self.heads, self.out_features_dim)

        return edge_features
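
# The following is an illustrative usage sketch, not part of the library source.
# It exercises the edge-feature path above: `edge_features_dim=8` and the random
# `edge_features` array are assumptions chosen for the example, and the inputs
# follow the `[2, num_edges]` edge_index convention used in `__call__`.
if __name__ == "__main__":
    conv = GATv2Conv(16, 32, heads=2, concat=True, edge_features_dim=8)

    # A small graph with 5 nodes and 5 edges.
    edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
    node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))
    # One 8-dimensional feature vector per edge, matching edge_features_dim.
    edge_features = mx.random.uniform(low=0, high=1, shape=(5, 8))

    h = conv(edge_index, node_features, edge_features)
    print(h.shape)  # 5 nodes x (heads * out_features_dim) = (5, 64) with concat=True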