mlx_graphs.nn.SimpleConv
- class mlx_graphs.nn.SimpleConv(aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add', combine_root_func: Literal['sum', 'cat', 'self_loop'] | None = None, **kwargs)

Bases: MessagePassing
A simple non-trainable message passing layer.
\[\mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j\]

where \(\bigoplus\) denotes an aggregation strategy (e.g. 'add', 'mean'), \(e_{j,i}\) denotes the edge weight between the source node \(j\) and the target node \(i\), \(\mathcal{N}(i)\) denotes the neighbors of node \(i\), and \(\mathbf{x}_j\) denotes the features of node \(j\).

Inspired by the SimpleConv PyG layer.
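As a sketch of the formula above, an illustrative plain-NumPy implementation of the 'add' aggregation (NumPy is used here only for illustration; the layer itself operates on mlx arrays) could look like:

```python
import numpy as np

def simple_conv_add(edge_index: np.ndarray,
                    node_features: np.ndarray,
                    edge_weights: np.ndarray) -> np.ndarray:
    """Compute x'_i = sum_{j in N(i)} e_{j,i} * x_j with 'add' aggregation.

    edge_index has shape (2, num_edges): row 0 holds source nodes j,
    row 1 holds target nodes i (the convention used by the example
    on this page).
    """
    num_nodes, num_feats = node_features.shape
    out = np.zeros((num_nodes, num_feats))
    for j, i, w in zip(edge_index[0], edge_index[1], edge_weights):
        # message from source j to target i, scaled by the edge weight e_{j,i}
        out[i] += w * node_features[j]
    return out

x = np.ones((3, 2))
edge_index = np.array([[0, 2], [1, 1]])   # edges 0 -> 1 and 2 -> 1
edge_weights = np.array([2.0, 3.0])
print(simple_conv_add(edge_index, x, edge_weights))
# node 1 receives 2*x_0 + 3*x_2 = [5., 5.]; nodes 0 and 2 receive nothing
```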
- Parameters:
  - aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – Aggregation strategy used to aggregate messages, e.g. 'add', 'mean', 'max'. Default: 'add'
  - combine_root_func (Optional[Literal['sum', 'cat', 'self_loop']]) – Strategy used to combine the features from the root nodes. Available values: 'sum', 'cat', 'self_loop' or None.
    - 'sum': sums the neighborhood's message and the root node's features.
    - 'cat': concatenates the neighborhood's message and the root node's features.
    - 'self_loop': adds a self-loop for each root node and aggregates the messages. If the graph is weighted, the edge weights of the self-loops are set to 1. Default: None
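To make the 'sum' and 'cat' strategies concrete, here is an illustrative NumPy sketch (not the library's implementation) of how an aggregated message could be combined with the root node features:

```python
import numpy as np
from typing import Optional

def combine_root(msg: np.ndarray, x: np.ndarray, mode: Optional[str]) -> np.ndarray:
    """Illustrative combination of aggregated messages with root features."""
    if mode is None:
        return msg                                 # aggregated messages only
    if mode == "sum":
        return msg + x                             # element-wise sum, shape unchanged
    if mode == "cat":
        return np.concatenate([msg, x], axis=-1)   # feature dimension doubles
    raise ValueError(mode)

msg = np.ones((4, 3))       # aggregated neighborhood messages
x = 2 * np.ones((4, 3))     # root node features
print(combine_root(msg, x, "sum").shape)   # (4, 3)
print(combine_root(msg, x, "cat").shape)   # (4, 6)
```

Note that 'self_loop' is not a post-aggregation combination: per the description above, it modifies the graph (one self-loop per root node) before the messages are aggregated.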
Example:

import mlx.core as mx
from mlx_graphs.nn import SimpleConv

# Sum the messages from the neighbors.
# Use a self-loop for each root node.
conv = SimpleConv(aggr="add", combine_root_func="self_loop")

node_features = mx.ones((5, 3))
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_weights = mx.array([10, 20, 5, 2, 15])

h = conv(edge_index, node_features, edge_weights)
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_weights: mlx.core.array | None = None) → mlx.core.array

Computes the forward pass of SimpleConv.
Methods