mlx_graphs.nn.SAGEConv
- class mlx_graphs.nn.SAGEConv(node_features_dim: int, out_features_dim: int, bias: bool = True, **kwargs)
Bases: MessagePassing
GraphSAGE convolution layer from the “Inductive Representation Learning on Large Graphs” paper.
\[\mathbf{h}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \bigoplus_{j \in \mathcal{N}(i)} \mathbf{x}_j,\]

where \(\mathbf{x}_i\) represents the input features of node \(i\), \(\bigoplus\) denotes the aggregation function (mean by default), and \(\mathbf{h}_i\) is the computed embedding of node \(i\).
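To make the formula concrete, here is a minimal sketch of the update rule in plain mlx, assuming mean aggregation over incoming edges with edge_index[0] as sources and edge_index[1] as destinations. The names sage_update, w1, and w2 are illustrative only; the actual layer is built on MessagePassing and may differ in details.

```python
import mlx.core as mx
import mlx.nn as nn

def sage_update(edge_index: mx.array, x: mx.array,
                w1: nn.Linear, w2: nn.Linear) -> mx.array:
    # Sketch of h_i = W1 x_i + W2 * mean_{j in N(i)} x_j (mean aggregation).
    src, dst = edge_index[0], edge_index[1]
    num_nodes = x.shape[0]
    # Sum each source node's features into its destination node ...
    agg = mx.zeros_like(x).at[dst].add(x[src])
    # ... then divide by the in-degree to obtain the mean
    # (degree clamped to 1 so isolated nodes yield zeros, not NaNs).
    deg = mx.zeros((num_nodes,)).at[dst].add(mx.ones(dst.shape))
    mean_neigh = agg / mx.maximum(deg, 1)[:, None]
    return w1(x) + w2(mean_neigh)

# Toy usage with the same graph as the example below.
x = mx.ones((5, 16))
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
h = sage_update(edge_index, x, nn.Linear(16, 32), nn.Linear(16, 32))
print(h.shape)  # (5, 32)
```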
- Parameters:
  - node_features_dim (int) – Size of input node features.
  - out_features_dim (int) – Size of output node embeddings.
  - bias (bool, optional) – Whether to use a bias term. Default: True.
Example:

```python
import mlx.core as mx
from mlx_graphs.data.data import GraphData
from mlx_graphs.nn import SAGEConv

graph = GraphData(
    edge_index=mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]),
    node_features=mx.ones((5, 16)),
)

conv = SAGEConv(16, 32)
h = conv(graph.edge_index, graph.node_features)

>>> h
array([[1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
       [1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
       [1.05823, -0.295776, 0.075439, ..., -0.104383, 0.031947, -0.351897],
       [1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
       [1.05823, -0.295776, 0.075439, ..., -0.104383, 0.031947, -0.351897]], dtype=float32)
```
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_weights: mlx.core.array | None = None) → mlx.core.array
Computes the forward pass of SAGEConv.
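The edge_weights argument is optional. Below is a brief sketch of calling the layer with per-edge weights; the weight values are illustrative, and the assumption that each weight scales its neighbor's message before aggregation reflects typical MessagePassing behavior rather than a documented guarantee.

```python
import mlx.core as mx
from mlx_graphs.nn import SAGEConv

conv = SAGEConv(16, 32)
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
node_features = mx.ones((5, 16))

# Illustrative per-edge weights, one per column of edge_index.
edge_weights = mx.array([1.0, 0.5, 2.0, 1.0, 0.25])

h = conv(edge_index, node_features, edge_weights)
print(h.shape)  # (5, 32)
```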