mlx_graphs.nn.GINConv
class mlx_graphs.nn.GINConv(mlp: mlx.nn.Module, eps: float = 0.0, learn_eps: bool = False, node_features_dim: int | None = None, edge_features_dim: int | None = None, **kwargs)
Bases: MessagePassing
Graph Isomorphism Network convolution layer from the paper “How Powerful are Graph Neural Networks?”.
\[\mathbf{h}_i = \text{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]

where \(\mathbf{x}_i\) and \(\mathbf{h}_i\) represent the input features and output embeddings of node \(i\), respectively. \(\text{MLP}\) denotes a custom neural network provided by the user, and \(\epsilon\) is a scalar that is either fixed or learned.
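As a rough illustration of the update rule above, the following pure-Python sketch applies it to scalar node features. The helper name `gin_update`, the identity default for `mlp`, and the convention that the first row of `edge_index` holds source nodes and the second row holds target nodes are assumptions made for this example; it is not the library's implementation.

```python
def gin_update(x, edge_index, eps=0.0, mlp=lambda h: h):
    """Compute h_i = MLP((1 + eps) * x_i + sum of x_j over neighbors j of i).

    Hypothetical helper: `x` is a list of scalar node features and
    `edge_index` is a pair of lists (sources, targets).
    """
    sources, targets = edge_index
    # Sum the features of incoming neighbors for every node.
    aggregated = [0.0] * len(x)
    for src, dst in zip(sources, targets):
        aggregated[dst] += x[src]
    # Combine the scaled self feature with the neighbor sum, then apply the MLP.
    return [mlp((1.0 + eps) * x_i + agg_i) for x_i, agg_i in zip(x, aggregated)]


x = [1.0, 2.0, 3.0]
edge_index = [[1, 2, 0], [0, 0, 1]]  # edges 1->0, 2->0, 0->1
h = gin_update(x, edge_index, eps=0.5)
print(h)  # [6.5, 4.0, 4.5]
```

Node 2 has no incoming edges here, so its output depends only on its own feature scaled by \(1 + \epsilon\).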
Setting edge_features_dim produces a GINEConv model, in which edge_features must be passed to the forward call. In this case, edge features are first projected to the same dimension as the node embeddings, summed with them, and then passed through a ReLU activation. Using GINEConv also requires setting node_features_dim.

Parameters:
- mlp (Module) – Callable mlx.nn.Module applied to the final node embeddings
- eps (float) – Initial value of the \(\epsilon\) term. Default: 0
- learn_eps (bool) – Whether to learn \(\epsilon\). Default: False
- edge_features_dim (Optional[int]) – Size of the edge features passed to the GINE layer
- node_features_dim (Optional[int]) – Size of the node features (only required if GINE is used)
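The GINE variant described above (project the edge features, add them to the node features, apply ReLU) can be sketched per edge in pure Python. The names `project`, `relu`, and `gine_message`, the bias-free projection, and the weight matrix `W` are hypothetical choices for this illustration; the layer's actual projection may differ, for example by including a bias term.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def project(weight, edge_features):
    """Linear projection without bias: one output per row of `weight`."""
    return [sum(w * e for w, e in zip(row, edge_features)) for row in weight]


def gine_message(src_features, edge_features, weight):
    """Message for one edge: ReLU(x_j + W e_ij), with W assumed bias-free."""
    projected = project(weight, edge_features)
    return relu([x + p for x, p in zip(src_features, projected)])


x_src = [1.0, -2.0]        # node features of the source node (dimension 2)
e = [1.0, 0.0, 2.0]        # edge features (dimension 3)
W = [[0.5, 0.0, 1.0],      # maps edge dimension 3 -> node dimension 2
     [1.0, 1.0, 0.0]]
msg = gine_message(x_src, e, W)
print(msg)  # [3.5, 0.0]
```

The projection is what makes node_features_dim necessary: the edge features have to end up with the same size as the node features before the two can be added.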
Example:
    import mlx.core as mx
    import mlx.nn as nn
    from mlx_graphs.nn import GINConv

    node_feat_dim = 16
    edge_feat_dim = 10
    out_feat_dim = 32

    mlp = nn.Sequential(
        nn.Linear(node_feat_dim, node_feat_dim * 2),
        nn.ReLU(),
        nn.Linear(node_feat_dim * 2, out_feat_dim),
    )

    # original GINConv for node features
    conv = GINConv(mlp)
    edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
    node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))

    >>> conv(edge_index, node_features)
    array([[-0.536501, 0.154826, 0.745569, ..., 0.31547, -0.0962588, -0.108504],
           [-0.415889, -0.0498145, 0.597379, ..., 0.194553, -0.251498, -0.207561],
           [-0.119966, -0.0159533, 0.276559, ..., 0.0258303, -0.194533, -0.15515],
           [-0.21477, -0.169684, 0.485867, ..., 0.0194768, -0.145761, -0.139433],
           [-0.133289, -0.0279559, 0.358095, ..., -0.0443346, -0.11571, -0.114396]], dtype=float32)

    # GINEConv including edge features:
    conv = GINConv(mlp, edge_features_dim=edge_feat_dim, node_features_dim=node_feat_dim)
    edge_features = mx.random.uniform(low=0, high=1, shape=(5, edge_feat_dim))

    >>> conv(edge_index, node_features, edge_features)
    array([[-0.175581, 0.67481, -0.260592, ..., -1.13234, -0.631736, 0.572239],
           [0.0536669, 0.496115, -0.319334, ..., -1.165, -0.573817, 0.495315],
           [-0.0505168, 0.102068, 0.0221924, ..., -0.516901, -0.331266, 0.317491],
           [-0.00632942, 0.433597, -0.162906, ..., -0.957552, -0.41922, 0.670711],
           [-0.119726, 0.173545, 0.0951687, ..., -0.577839, -0.244039, 0.399055]], dtype=float32)
__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_features: mlx.core.array | None = None, edge_weights: mlx.core.array | None = None) → mlx.core.array
Computes the forward pass of GINConv.
- Parameters: edge_index (mlx.core.array), node_features (mlx.core.array), edge_features (mlx.core.array | None), edge_weights (mlx.core.array | None)
- Return type: mlx.core.array
- Returns: The computed node embeddings
Methods
message(src_features, dst_features, **kwargs) – Computes messages between connected nodes.
message(src_features: mlx.core.array, dst_features: mlx.core.array, **kwargs) → mlx.core.array
Computes messages between connected nodes.
By default, returns the features of source nodes. Optional edge_weights can be passed directly in kwargs.
- Parameters: src_features (mlx.core.array), dst_features (mlx.core.array), **kwargs
- Return type: mlx.core.array
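To make the default message behaviour concrete, here is a small pure-Python sketch that returns the source features unchanged and, when edge_weights are given, scales each edge's message by its weight. The scaling is an assumption about how the weights are combined; the text above only states that they can be supplied through kwargs. The explicit `edge_weights` keyword stands in for `**kwargs` in this illustration.

```python
def message(src_features, dst_features, edge_weights=None):
    """Return one message per edge, taken from the source node features.

    Hypothetical stand-in: `src_features` is a list of per-edge feature
    lists; `dst_features` is accepted but unused, matching the documented
    default of returning source features.
    """
    if edge_weights is None:
        # Default case: the message is a copy of the source features.
        return [list(row) for row in src_features]
    # Assumed weighting: multiply every feature of an edge by that edge's weight.
    return [[w * v for v in row] for row, w in zip(src_features, edge_weights)]


src = [[1.0, 2.0], [3.0, 4.0]]   # features of the source node of each edge
dst = [[0.0, 0.0], [0.0, 0.0]]   # features of the target node of each edge
plain = message(src, dst)
weighted = message(src, dst, edge_weights=[0.5, 2.0])
print(plain)     # [[1.0, 2.0], [3.0, 4.0]]
print(weighted)  # [[0.5, 1.0], [6.0, 8.0]]
```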