mlx_graphs.nn.SimpleConv

class mlx_graphs.nn.SimpleConv(aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add', combine_root_func: Literal['sum', 'cat', 'self_loop'] | None = None, **kwargs)

Bases: MessagePassing

A simple non-trainable message passing layer.

\[\mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j\]

where \(\bigoplus\) denotes the aggregation strategy (e.g. 'add', 'mean'), \(\mathcal{N}(i)\) denotes the set of neighbors of node \(i\), \(e_{j,i}\) denotes the weight of the edge from source node \(j\) to target node \(i\), and \(\mathbf{x}_j\) denotes the features of node \(j\).

Inspired by the SimpleConv layer of PyG (PyTorch Geometric).
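To make the formula concrete, here is a minimal NumPy sketch of the message-passing rule for the 'add', 'mean' and 'max' aggregations. It is not the mlx_graphs implementation: the function name `simple_conv_reference` is hypothetical, and the sketch assumes that row 0 of `edge_index` holds the source nodes \(j\) and row 1 the target nodes \(i\), that a missing edge weight counts as 1, and that nodes with no incoming edges receive a zero vector. The 'softmax' and 'min' strategies are not covered.

```python
import numpy as np

def simple_conv_reference(edge_index, x, edge_weights=None, aggr="add"):
    """Hypothetical NumPy reference for x'_i = AGG_{j in N(i)} e_{j,i} * x_j.

    Assumes edge_index[0] holds source nodes j and edge_index[1] target nodes i.
    """
    src, dst = edge_index
    num_nodes, num_feats = x.shape
    # Missing edge weights are treated as 1 (an assumption of this sketch).
    if edge_weights is None:
        w = np.ones(len(src), dtype=x.dtype)
    else:
        w = np.asarray(edge_weights, dtype=x.dtype)
    messages = w[:, None] * x[src]  # one weighted message e_{j,i} * x_j per edge
    in_degree = np.bincount(dst, minlength=num_nodes)

    if aggr in ("add", "mean"):
        out = np.zeros((num_nodes, num_feats), dtype=x.dtype)
        np.add.at(out, dst, messages)  # scatter-add each message into its target
        if aggr == "mean":
            out /= np.maximum(in_degree, 1).astype(x.dtype)[:, None]
        return out
    if aggr == "max":
        out = np.full((num_nodes, num_feats), -np.inf, dtype=x.dtype)
        np.maximum.at(out, dst, messages)  # element-wise max per target node
        out[in_degree == 0] = 0.0  # nodes without neighbors get zeros
        return out
    raise ValueError(f"aggregation not covered by this sketch: {aggr}")

# Same edges and weights as the class example, without self-loops.
x = np.ones((5, 3), dtype=np.float32)
edge_index = np.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_weights = np.array([10, 20, 5, 2, 15])
out = simple_conv_reference(edge_index, x, edge_weights)
# column 0 of out: 30 (= 10 + 20), 7 (= 5 + 2), 0, 15, 0
```

Node 0 receives \(10\,\mathbf{x}_0 + 20\,\mathbf{x}_1\) because both edges 0 → 0 and 1 → 0 point into it; nodes 2 and 4 have no incoming edges in this sketch and therefore stay at zero.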

Parameters:
  • aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – Strategy used to aggregate the incoming messages: 'add', 'max', 'mean', 'softmax' or 'min'. Default: 'add'

  • combine_root_func (Optional[Literal['sum', 'cat', 'self_loop']]) – Strategy used to combine the aggregated neighborhood messages with the root (target) node’s own features. Available values: 'sum', 'cat', 'self_loop' or None. 'sum': adds the root node’s features to the aggregated message. 'cat': concatenates the aggregated message and the root node’s features. 'self_loop': adds a self-loop to every root node before aggregating, so each node also receives its own features as a message; if the graph is weighted, the self-loop edges get weight 1. None: returns only the aggregated messages. Default: None
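The three combine_root_func strategies can be sketched in NumPy as follows, using 'add' aggregation. This is an illustrative sketch, not the library's code: the helper names `scatter_add` and `combine_root` are hypothetical, and the sketch assumes row 0 of `edge_index` holds sources and row 1 targets, and that 'self_loop' appends one weight-1 edge i → i per node even if the node already has one.

```python
import numpy as np

def scatter_add(edge_index, x, w):
    """Weighted 'add' aggregation; row 0 = sources, row 1 = targets (assumed)."""
    src, dst = edge_index
    out = np.zeros_like(x)
    np.add.at(out, dst, w[:, None] * x[src])
    return out

def combine_root(edge_index, x, edge_weights=None, mode=None):
    """Hypothetical sketch of the combine_root_func strategies with 'add' aggregation."""
    num_nodes = x.shape[0]
    if edge_weights is None:
        w = np.ones(edge_index.shape[1], dtype=x.dtype)
    else:
        w = np.asarray(edge_weights, dtype=x.dtype)
    if mode == "self_loop":
        # Append one edge i -> i per node, each with weight 1, then aggregate.
        loops = np.arange(num_nodes)
        edge_index = np.concatenate([edge_index, np.stack([loops, loops])], axis=1)
        w = np.concatenate([w, np.ones(num_nodes, dtype=x.dtype)])
        return scatter_add(edge_index, x, w)
    agg = scatter_add(edge_index, x, w)
    if mode is None:
        return agg  # neighborhood messages only
    if mode == "sum":
        return agg + x  # add the root node's own features
    if mode == "cat":
        return np.concatenate([agg, x], axis=-1)  # feature dimension doubles
    raise ValueError(f"unknown mode: {mode}")

x = np.ones((3, 2), dtype=np.float32)
edge_index = np.array([[0, 1], [1, 2]])  # edges 0 -> 1 and 1 -> 2
edge_weights = np.array([2.0, 3.0])
print(combine_root(edge_index, x, edge_weights, "cat").shape)  # (3, 4)
```

Under these assumptions, 'self_loop' and 'sum' give the same result with 'add' aggregation on a graph without pre-existing self-loops. They diverge for 'mean' or 'max', where the self-loop message is averaged or maxed together with the neighbor messages instead of being added afterwards.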

Example:

import mlx.core as mx
from mlx_graphs.nn import SimpleConv

# Sum the messages from the neighbors.
# Use a self-loop for each root node.
conv = SimpleConv(aggr="add", combine_root_func="self_loop")
node_features = mx.ones((5, 3))
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_weights = mx.array([10, 20, 5, 2, 15])
h = conv(edge_index, node_features, edge_weights)

__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_weights: mlx.core.array | None = None) → mlx.core.array

Computes the forward pass of SimpleConv.

Parameters:
  • edge_index (array) – Input edge index of shape [2, num_edges]

  • node_features (array) – Input node features of shape [num_nodes, num_node_features]

  • edge_weights (Optional[array]) – Optional edge weights \(e_{j,i}\), one scalar per edge, used to scale each message before aggregation. Default: None

Return type:

array

Returns:

The computed node embeddings
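Edge lists are often stored as (source, target) pairs rather than as a [2, num_edges] array. A small NumPy helper like the hypothetical `to_edge_index` below transposes such pairs into that layout and checks that every index refers to an existing node; the result can then be wrapped with `mx.array(...)` before calling the layer. The assumption that row 0 holds sources and row 1 holds targets mirrors the class example but is not stated explicitly in this reference.

```python
import numpy as np

def to_edge_index(edges, num_nodes):
    """Hypothetical helper: convert (source, target) pairs to a [2, num_edges] array."""
    # reshape(-1, 2) also handles an empty edge list, giving shape [2, 0] after .T
    edge_index = np.asarray(edges, dtype=np.int64).reshape(-1, 2).T
    if edge_index.size and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        raise ValueError("edge references a node outside [0, num_nodes)")
    # Contiguous copy so the array converts cleanly to other array libraries.
    return np.ascontiguousarray(edge_index)

# Same graph as the class example: 0->0, 1->0, 2->1, 3->1, 4->3
edge_index = to_edge_index([(0, 0), (1, 0), (2, 1), (3, 1), (4, 3)], num_nodes=5)
print(edge_index.shape)  # (2, 5)
```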

Methods