mlx_graphs.nn.SimpleConv
- class mlx_graphs.nn.SimpleConv(aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add', combine_root_func: Literal['sum', 'cat', 'self_loop'] | None = None, **kwargs)
Bases: MessagePassing

A simple non-trainable message passing layer.
\[\mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j\]

where \(\bigoplus\) denotes the aggregation strategy (e.g. 'add' or 'mean'), \(e_{j,i}\) denotes the weight of the edge from the source node \(j\) to the target node \(i\), \(\mathcal{N}(i)\) denotes the set of neighbors of node \(i\), and \(\mathbf{x}_j\) denotes the features of node \(j\).

Inspired by the SimpleConv layer of PyG (PyTorch Geometric).
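The formula can be traced by hand with a minimal plain-Python sketch. It is not the mlx_graphs implementation: the function name simple_conv_reference is hypothetical, and the sketch assumes that row 0 of edge_index holds the source nodes \(j\) and row 1 the target nodes \(i\), that a missing edge_weights means every \(e_{j,i}\) is 1, and that a node with no incoming edges receives a zero vector. The 'softmax' strategy is left out.

```python
# Hypothetical plain-Python sketch of x'_i = AGG_{j in N(i)} e_{j,i} * x_j.
# Assumptions (not taken from mlx_graphs): edge_index[0] = sources j,
# edge_index[1] = targets i, unweighted edges have e_{j,i} = 1, and nodes
# without incoming edges get a zero vector.
from typing import List, Optional, Sequence


def simple_conv_reference(
    edge_index: Sequence[Sequence[int]],
    node_features: List[List[float]],
    edge_weights: Optional[Sequence[float]] = None,
    aggr: str = "add",
) -> List[List[float]]:
    num_nodes = len(node_features)
    dim = len(node_features[0])
    sources, targets = edge_index
    if edge_weights is None:
        edge_weights = [1.0] * len(sources)  # unweighted graph: e_{j,i} = 1
    # Collect the weighted messages e_{j,i} * x_j arriving at each target i.
    inbox: List[List[List[float]]] = [[] for _ in range(num_nodes)]
    for j, i, w in zip(sources, targets, edge_weights):
        inbox[i].append([w * v for v in node_features[j]])
    out = []
    for msgs in inbox:
        if not msgs:
            out.append([0.0] * dim)  # no neighbors: zero vector (assumption)
            continue
        cols = list(zip(*msgs))  # one tuple of message values per feature dimension
        if aggr == "add":
            out.append([sum(c) for c in cols])
        elif aggr == "mean":
            out.append([sum(c) / len(c) for c in cols])
        elif aggr == "max":
            out.append([max(c) for c in cols])
        elif aggr == "min":
            out.append([min(c) for c in cols])
        else:
            raise ValueError(f"aggregation {aggr!r} is not covered by this sketch")
    return out
```

For example, with features [[1, 0], [0, 1], [5, 5]], edges 0 → 2 and 1 → 2, and weights 2 and 3, node 2 receives the messages [2, 0] and [0, 3], so 'add' gives [2, 3] and 'mean' gives [1, 1.5].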
- Parameters:
  - aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – Aggregation strategy used to aggregate messages, e.g. 'add', 'mean', 'max'. Default: 'add'
  - combine_root_func (Optional[Literal['sum', 'cat', 'self_loop']]) – Strategy used to combine the features from the root nodes. Available values: 'sum', 'cat', 'self_loop' or None. Default: None
    - 'sum': sums the neighborhood's message and the root node's features.
    - 'cat': concatenates the neighborhood's message and the root node's features.
    - 'self_loop': adds a self-loop for each root node and then aggregates the messages. If the graph is weighted, the edge weights of the self-loops are set to 1.
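The combine_root_func options can be illustrated with another plain-Python sketch. The helper names aggregate_add and combine_root are hypothetical, only the 'add' aggregation is modelled, edge_index is assumed to list sources in row 0 and targets in row 1, and the order of the 'cat' concatenation (message first, then root features) is an assumption taken from the parameter description rather than from the library source.

```python
# Hypothetical sketch of the combine_root_func strategies, using 'add' aggregation only.
# Assumes edge_index[0] = sources, edge_index[1] = targets.
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def aggregate_add(edge_index: Sequence[Sequence[int]], node_features: Matrix,
                  edge_weights: Sequence[float]) -> Matrix:
    """Weighted 'add' aggregation: x'_i = sum over edges (j -> i) of e_{j,i} * x_j."""
    dim = len(node_features[0])
    out = [[0.0] * dim for _ in node_features]
    for j, i, w in zip(edge_index[0], edge_index[1], edge_weights):
        for d in range(dim):
            out[i][d] += w * node_features[j][d]
    return out


def combine_root(edge_index: Sequence[Sequence[int]], node_features: Matrix,
                 edge_weights: Optional[Sequence[float]] = None,
                 combine_root_func: Optional[str] = None) -> Matrix:
    """Mimics the combine_root_func options under the assumptions stated above."""
    n = len(node_features)
    weights = list(edge_weights) if edge_weights is not None else [1.0] * len(edge_index[0])
    if combine_root_func == "self_loop":
        # Append one self-loop i -> i per node with weight 1, then aggregate as usual.
        loops = list(range(n))
        edge_index = [list(edge_index[0]) + loops, list(edge_index[1]) + loops]
        return aggregate_add(edge_index, node_features, weights + [1.0] * n)
    messages = aggregate_add(edge_index, node_features, weights)
    if combine_root_func is None:
        return messages  # neighborhood messages only
    if combine_root_func == "sum":
        # Element-wise sum of the aggregated message and the root node's own features.
        return [[m + x for m, x in zip(m_row, x_row)]
                for m_row, x_row in zip(messages, node_features)]
    if combine_root_func == "cat":
        # Concatenate message and root features (order assumed); output width doubles.
        return [m_row + list(x_row) for m_row, x_row in zip(messages, node_features)]
    raise ValueError(f"unknown combine_root_func: {combine_root_func!r}")
```

In this sketch, with 'add' aggregation and weight-1 self-loops, 'self_loop' and 'sum' produce the same values; they generally differ for aggregations such as 'mean' or 'max', where the self-loop message takes part in the aggregation itself. 'cat' doubles the feature dimension.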
Example:

```python
import mlx.core as mx
from mlx_graphs.nn import SimpleConv

# Sum the messages from the neighbors.
# Use a self-loop for each root node.
conv = SimpleConv(aggr="add", combine_root_func="self_loop")

node_features = mx.ones((5, 3))
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_weights = mx.array([10, 20, 5, 2, 15])

h = conv(edge_index, node_features, edge_weights)
```
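The example does not print h. As a rough hand-check, the plain-Python snippet below recomputes what h should contain, assuming that row 0 of edge_index holds the sources, row 1 the targets, and that one self-loop with weight 1 is appended for every node, including node 0, which already has the edge 0 → 0. If the library skips duplicate self-loops, row 0 would differ.

```python
# Hand-check of the example (aggr="add", combine_root_func="self_loop") in plain Python.
# Assumptions: edge_index[0] = sources, edge_index[1] = targets, and one weight-1
# self-loop is appended for every node (node 0 therefore gets a second one).
sources = [0, 1, 2, 3, 4]
targets = [0, 0, 1, 1, 3]
weights = [10, 20, 5, 2, 15]
num_nodes, dim = 5, 3
node_features = [[1.0] * dim for _ in range(num_nodes)]  # same as mx.ones((5, 3))

# Append the self-loops i -> i with weight 1.
sources += list(range(num_nodes))
targets += list(range(num_nodes))
weights += [1] * num_nodes

# 'add' aggregation: h[i] = sum over edges (j -> i) of e_{j,i} * x_j.
h = [[0.0] * dim for _ in range(num_nodes)]
for j, i, w in zip(sources, targets, weights):
    for d in range(dim):
        h[i][d] += w * node_features[j][d]

for row in h:
    print(row)
```

Under these assumptions every row of h is constant across its three features: 31, 8, 1, 16 and 1 for nodes 0 to 4.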
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_weights: mlx.core.array | None = None) → mlx.core.array
Computes the forward pass of SimpleConv.