# Source code for mlx_graphs.nn.conv.sage_conv

from typing import Optional

import mlx.core as mx

from mlx_graphs.nn.linear import Linear
from mlx_graphs.nn.message_passing import MessagePassing

class SAGEConv(MessagePassing):
    r"""GraphSAGE convolution layer from the `"Inductive Representation Learning
    on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper.

    .. math::
        \mathbf{h}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \bigoplus_{j \in \mathcal{N}(i)} \mathbf{x}_j,

    where :math:`\mathbf{x}_i` represents the input features of node :math:`i`,
    :math:`\bigoplus` denotes the aggregation function, set by default to `mean`,
    and :math:`\mathbf{h}_i` is the computed embedding of node :math:`i`.

    Args:
        node_features_dim: Size of input node features
        out_features_dim: Size of output node embeddings
        bias: Whether to use bias in the node (self) projection. The neighbor
            projection :math:`\mathbf{W}_2` never has a bias.
        **kwargs: Forwarded to :class:`MessagePassing`. ``aggr`` defaults to
            ``"mean"`` when not given.

    Example:

    .. code-block:: python

        import mlx.core as mx
        from mlx_graphs.data.data import GraphData
        from mlx_graphs.nn import SAGEConv

        graph = GraphData(
            edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]),
            node_features = mx.ones((5, 16)),
        )

        conv = SAGEConv(16, 32)
        h = conv(graph.edge_index, graph.node_features)

        >>> h
        array([[1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
               [1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
               [1.05823, -0.295776, 0.075439, ..., -0.104383, 0.031947, -0.351897],
               [1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
               [1.05823, -0.295776, 0.075439, ..., -0.104383, 0.031947, -0.351897]], dtype=float32)
    """

    # Aggregations that commute with a bias-free linear map, i.e. for which
    # W @ aggr({x_j}) == aggr({W @ x_j}). Only for these is it safe to apply
    # the neighbor projection before message passing.
    _LINEAR_AGGREGATIONS = ("add", "sum", "mean")

    def __init__(
        self,
        node_features_dim: int,
        out_features_dim: int,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "mean")
        super().__init__(**kwargs)

        self.node_features_dim = node_features_dim
        self.out_features_dim = out_features_dim

        # Whether the "project first" shortcut in __call__ yields the same
        # result as "aggregate first" for the configured aggregation.
        self._linear_aggr = kwargs["aggr"] in self._LINEAR_AGGREGATIONS

        # bias=False keeps neigh_proj purely linear, which the shortcut relies on.
        self.neigh_proj = Linear(node_features_dim, out_features_dim, bias=False)
        self.self_proj = Linear(node_features_dim, out_features_dim, bias=bias)

    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_weights: Optional[mx.array] = None,
    ) -> mx.array:
        """Computes the forward pass of SAGEConv.

        Args:
            edge_index: Input edge index of shape `[2, num_edges]`
            node_features: Input node features
            edge_weights: Edge weights leveraged in message passing
                (forwarded to ``propagate`` via ``message_kwargs``).
                Default: ``None``

        Returns:
            mx.array: The computed node embeddings
        """
        message_kwargs = {"edge_weights": edge_weights}

        # Following DGL, project on the smaller feature dimension: if the output
        # is narrower than the input, project first so message passing moves
        # fewer features. This reordering is only exact for linear aggregations
        # (e.g. for "max", max(W x_j) != W max(x_j)), so other aggregations
        # always aggregate first, matching the documented formula.
        project_first = (
            self._linear_aggr and self.node_features_dim > self.out_features_dim
        )

        if project_first:
            neigh_features = self.propagate(
                edge_index=edge_index,
                node_features=self.neigh_proj(node_features),
                message_kwargs=message_kwargs,
            )
        else:
            aggregated = self.propagate(
                edge_index=edge_index,
                node_features=node_features,
                message_kwargs=message_kwargs,
            )
            neigh_features = self.neigh_proj(aggregated)

        return self.self_proj(node_features) + neigh_features