

class mlx_graphs.nn.GINConv(mlp: mlx.nn.Module, eps: float = 0.0, learn_eps: bool = False, node_features_dim: int | None = None, edge_features_dim: int | None = None, **kwargs)

Bases: MessagePassing

Graph Isomorphism Network convolution layer from the paper “How Powerful are Graph Neural Networks?”.

\[\mathbf{h}_i = \text{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]

where \(\mathbf{x}_i\) and \(\mathbf{h}_i\) represent the input features and output embeddings of node \(i\), respectively. \(\text{MLP}\) denotes a custom neural network provided by the user and \(\epsilon\) is an epsilon value either fixed or learned.
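The update rule above can be traced by hand with a small dependency-free sketch. This is not the library's implementation: the helper name `gin_update` is hypothetical, the MLP defaults to the identity so the arithmetic stays visible, and the first row of `edge_index` is assumed to hold source nodes and the second row target nodes.

```python
def gin_update(edge_index, x, eps=0.0, mlp=lambda h: h):
    """Reference GIN update: MLP((1 + eps) * x_i + sum of neighbor features)."""
    sources, targets = edge_index  # assumption: row 0 = sources, row 1 = targets
    dim = len(x[0])
    # Sum the features of every source node into its target node.
    aggregated = [[0.0] * dim for _ in x]
    for src, dst in zip(sources, targets):
        for k in range(dim):
            aggregated[dst][k] += x[src][k]
    # Add the (1 + eps)-scaled self features, then apply the MLP per node.
    return [
        mlp([(1.0 + eps) * xi[k] + agg[k] for k in range(dim)])
        for xi, agg in zip(x, aggregated)
    ]


edge_index = [[0, 1, 2], [1, 2, 1]]  # edges 0->1, 1->2, 2->1
x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
out = gin_update(edge_index, x, eps=0.5)
print(out)  # [[1.5, 0.0], [4.0, 4.0], [4.5, 3.5]]
```

Node 0 has no incoming edges, so its output is only its own features scaled by \(1 + \epsilon\). Node 1 receives the features of nodes 0 and 2 on top of its scaled self term.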

Setting edge_features_dim turns the layer into a GINEConv, which expects edge_features to be passed in the forward pass. In this case, edge features are first projected to the same dimension as the node features, added to the node features, and passed through a ReLU activation. Using GINEConv also requires setting node_features_dim.
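The GINE message described above can be sketched for a single edge in plain Python. The helper names, the bias-free projection, and the weight values are illustrative assumptions, not the library's internals. The sketch also assumes the projected edge features are added to the source node's features, as in the GINE formulation.

```python
def relu(v):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, a) for a in v]


def project(weight, e):
    """Bias-free linear projection; weight has shape [out_dim][in_dim]."""
    return [sum(w * a for w, a in zip(row, e)) for row in weight]


def gine_message(x_src, e, weight):
    """Message for one edge: ReLU(x_src + W e)."""
    projected = project(weight, e)
    return relu([a + b for a, b in zip(x_src, projected)])


# Hypothetical projection from 3 edge features to 2 node features.
weight = [[1.0, 0.0, -1.0], [0.0, 2.0, 0.0]]
msg = gine_message([2.0, -3.0], [1.0, 1.0, 2.0], weight)
print(msg)  # [1.0, 0.0]
```

The projection is what makes node_features_dim necessary: the layer has to know the target dimension before it can add edge features to node features.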

  • mlp (Module) – Callable mlx.nn.Module applied to the aggregated node embeddings

  • eps (float) – Initial value of the \(\epsilon\) term. Default: 0.0

  • learn_eps (bool) – Whether to learn \(\epsilon\) or keep it fixed. Default: False

  • edge_features_dim (Optional[int]) – Size of the edge features passed in the GINE layer

  • node_features_dim (Optional[int]) – Size of the node features (only required if GINE is used)


import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.nn import GINConv

node_feat_dim = 16
edge_feat_dim = 10
out_feat_dim = 32

mlp = nn.Sequential(
    nn.Linear(node_feat_dim, node_feat_dim * 2),
    nn.Linear(node_feat_dim * 2, out_feat_dim),
)

# original GINConv for node features

conv = GINConv(mlp)

edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))

>>> conv(edge_index, node_features)
array([[-0.536501, 0.154826, 0.745569, ..., 0.31547, -0.0962588, -0.108504],
    [-0.415889, -0.0498145, 0.597379, ..., 0.194553, -0.251498, -0.207561],
    [-0.119966, -0.0159533, 0.276559, ..., 0.0258303, -0.194533, -0.15515],
    [-0.21477, -0.169684, 0.485867, ..., 0.0194768, -0.145761, -0.139433],
    [-0.133289, -0.0279559, 0.358095, ..., -0.0443346, -0.11571, -0.114396]],
    dtype=float32)

# GINEConv including edge features:

conv = GINConv(mlp, edge_features_dim=edge_feat_dim, node_features_dim=node_feat_dim)
edge_features = mx.random.uniform(low=0, high=1, shape=(5, edge_feat_dim))

>>> conv(edge_index, node_features, edge_features)
array([[-0.175581, 0.67481, -0.260592, ..., -1.13234, -0.631736, 0.572239],
    [0.0536669, 0.496115, -0.319334, ..., -1.165, -0.573817, 0.495315],
    [-0.0505168, 0.102068, 0.0221924, ..., -0.516901, -0.331266, 0.317491],
    [-0.00632942, 0.433597, -0.162906, ..., -0.957552, -0.41922, 0.670711],
    [-0.119726, 0.173545, 0.0951687, ..., -0.577839, -0.244039, 0.399055]],
    dtype=float32)

__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_features: mlx.core.array | None = None, edge_weights: mlx.core.array | None = None) -> mlx.core.array

Computes the forward pass of GINConv.

  • edge_index (array) – Input edge index of shape [2, num_edges]

  • node_features (array) – Input node features

  • edge_features (Optional[array]) – Input edge features. Default: None

  • edge_weights (Optional[array]) – Edge weights leveraged in message passing. Default: None

Return type: array

Returns: The computed node embeddings
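As an illustration of the [2, num_edges] layout expected for edge_index, the sketch below converts a list of (source, target) pairs into two parallel rows. The helper `to_edge_index` is hypothetical and not part of mlx_graphs, and the source-first row order is an assumption. The resulting nested list could then be wrapped with mx.array as in the example above.

```python
def to_edge_index(pairs):
    """Convert [(src, dst), ...] into the [2, num_edges] layout [[srcs], [dsts]]."""
    sources = [src for src, _ in pairs]
    targets = [dst for _, dst in pairs]
    return [sources, targets]


# Same edges as in the example above: 0->0, 1->0, 2->1, 3->1, 4->3.
pairs = [(0, 0), (1, 0), (2, 1), (3, 1), (4, 3)]
edge_index = to_edge_index(pairs)
print(edge_index)  # [[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]
```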


message(src_features: mlx.core.array, dst_features: mlx.core.array, **kwargs) -> mlx.core.array

Computes messages between connected nodes.

By default, returns the features of the source nodes. Optional edge_weights can be passed through kwargs to scale these features.

  • src_features (array) – Source node embeddings

  • dst_features (array) – Destination node embeddings

  • edge_weights – Array of scalars with shape (num_edges,) or (num_edges, 1) used to weigh neighbor features during aggregation. Default: None

  • **kwargs – Optional args to compute messages
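The default behavior described above (return source features, optionally scaled by edge_weights) can be sketched without any dependencies. The function `weighted_message` is a hypothetical stand-in rather than the library's code. It accepts weights as either a flat list, mirroring shape (num_edges,), or a list of single-item rows, mirroring shape (num_edges, 1).

```python
def weighted_message(src_features, edge_weights=None):
    """Return per-edge messages: source features, optionally scaled per edge."""
    if edge_weights is None:
        return [list(row) for row in src_features]
    messages = []
    for row, w in zip(src_features, edge_weights):
        # Accept both (num_edges,) scalars and (num_edges, 1) single-item rows.
        scale = w[0] if isinstance(w, (list, tuple)) else w
        messages.append([scale * a for a in row])
    return messages


src = [[1.0, 2.0], [3.0, 4.0]]  # one row of source features per edge
flat = weighted_message(src, [0.5, 2.0])
column = weighted_message(src, [[0.5], [2.0]])
print(flat)  # [[0.5, 1.0], [6.0, 8.0]]
```

Both weight layouts give the same messages, which is why either shape is acceptable for aggregation.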

Return type: array
