mlx_graphs.nn.SAGEConv

class mlx_graphs.nn.SAGEConv(node_features_dim: int, out_features_dim: int, bias: bool = True, **kwargs)

Bases: MessagePassing

GraphSAGE convolution layer from the paper “Inductive Representation Learning on Large Graphs”.

\[\mathbf{h}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \bigoplus_{j \in \mathcal{N}(i)} \mathbf{x}_j,\]

where \(\mathbf{x}_i\) is the input feature vector of node \(i\), \(\mathbf{W}_1\) and \(\mathbf{W}_2\) are learnable weight matrices, \(\bigoplus\) is the aggregation function (mean by default), and \(\mathbf{h}_i\) is the computed embedding of node \(i\).
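The formula above can be sketched in plain Python to make the data flow explicit. This is an illustrative sketch only, not the library's implementation: the helper names (`matvec`, `mean_aggregate`, `sage_forward`) and the weight values are hypothetical, the bias term is omitted, and it assumes that row 0 of `edge_index` holds source nodes, row 1 holds target nodes, and that nodes without incoming edges receive a zero aggregate.

```python
def matvec(weight, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * v for w, v in zip(row, vector)) for row in weight]


def mean_aggregate(edge_index, node_features):
    """Average the features of each node's incoming neighbours."""
    sources, targets = edge_index
    dim = len(node_features[0])
    sums = [[0.0] * dim for _ in node_features]
    counts = [0] * len(node_features)
    for src, dst in zip(sources, targets):
        counts[dst] += 1
        for k in range(dim):
            sums[dst][k] += node_features[src][k]
    # Assumption: nodes without incoming edges keep a zero aggregate.
    return [
        [s / count for s in row] if count else row
        for row, count in zip(sums, counts)
    ]


def sage_forward(edge_index, node_features, w1, w2):
    """Compute h_i = W1 x_i + W2 * mean_{j in N(i)} x_j (bias omitted)."""
    aggregated = mean_aggregate(edge_index, node_features)
    return [
        [a + b for a, b in zip(matvec(w1, x), matvec(w2, agg))]
        for x, agg in zip(node_features, aggregated)
    ]


# Same topology as the Example section, with 2 features per node for brevity.
edge_index = [[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]
node_features = [[1.0, 1.0] for _ in range(5)]
w1 = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical weights standing in for W1
w2 = [[2.0, 0.0], [0.0, 2.0]]  # hypothetical weights standing in for W2

h = sage_forward(edge_index, node_features, w1, w2)
print(h)
```

With these inputs, nodes 0, 1 and 3 receive messages and share one embedding, while nodes 2 and 4 have no incoming edges and share another. The Example output shows the same pattern of identical rows.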

Parameters:
  • node_features_dim (int) – Size of input node features

  • out_features_dim (int) – Size of output node embeddings

  • bias (bool) – Whether to use bias in the node projection. Default: True

Example:

import mlx.core as mx
from mlx_graphs.data.data import GraphData
from mlx_graphs.nn import SAGEConv

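# Graph with 5 nodes and 16 features per node.
# Each column of edge_index is one edge: row 0 is the source, row 1 the target.
graph = GraphData(
    edge_index=mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]),
    node_features=mx.ones((5, 16)),
)

# Project 16 input features to 32-dimensional node embeddings.
conv = SAGEConv(16, 32)
h = conv(graph.edge_index, graph.node_features)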

>>> h
array([[1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
    [1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
    [1.05823, -0.295776, 0.075439, ..., -0.104383, 0.031947, -0.351897],
    [1.65429, -0.376169, 1.04172, ..., -0.919106, 1.42576, 0.490938],
    [1.05823, -0.295776, 0.075439, ..., -0.104383, 0.031947, -0.351897]],
    dtype=float32)

__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_weights: mlx.core.array | None = None) → mlx.core.array

Computes the forward pass of SAGEConv.

Parameters:
  • edge_index (array) – Input edge index of shape [2, num_edges]

  • node_features (array) – Input node features of shape [num_nodes, node_features_dim]

  • edge_weights (Optional[array]) – Edge weights leveraged in message passing. Default: None

Returns:

The computed node embeddings

Return type:

mx.array

Methods