mlx_graphs.nn.MessagePassing
- class mlx_graphs.nn.MessagePassing(aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add')
Bases: Module
Base class for creating Message Passing Neural Networks (MPNNs) [1].
Inherit from this class to build arbitrary GNN models based on the message passing paradigm. This implementation is inspired by PyTorch Geometric [2].
- Parameters:
aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – Aggregation strategy used to aggregate messages.
References
[1] Gilmer et al. Neural Message Passing for Quantum Chemistry.
[2] Fey et al. PyG (PyTorch Geometric).
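Informally, one propagation step of the general MPNN scheme can be written as below. The notation is our own: φ stands for message (source features x_j, destination features x_i), ⊕ for aggregate with the chosen aggr strategy, γ for update_nodes, and N(i) for the set of nodes with an edge into node i. This symbol-to-method mapping is our reading of the method descriptions on this page, not a formula taken from the library.

```latex
\mathbf{m}_{j \to i} = \phi\left(\mathbf{x}_j, \mathbf{x}_i\right), \qquad
\mathbf{a}_i = \bigoplus_{j \in \mathcal{N}(i)} \mathbf{m}_{j \to i}, \qquad
\mathbf{x}_i' = \gamma\left(\mathbf{a}_i\right)
```

Since update_nodes only receives the aggregated messages, any further input that γ needs (for example the node's own features) has to arrive through keyword arguments.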
- __call__(node_features: mlx.core.array, edge_index: mlx.core.array, **kwargs: Any)
Call self as a function.
Methods
aggregate(messages, indices, **kwargs) – Aggregates the messages using the self.aggr strategy.
message(src_features, dst_features, **kwargs) – Computes messages between connected nodes.
propagate(edge_index, node_features[, ...]) – Computes messages from neighbors, aggregates them and updates the final node embeddings.
update_nodes(aggregated, **kwargs) – Updates the final embeddings given the aggregated messages.
- aggregate(messages: mlx.core.array, indices: mlx.core.array, **kwargs) → mlx.core.array
Aggregates the messages using the self.aggr strategy.
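The grouping that aggregate performs can be pictured as a scatter-reduce: each message is routed to the node named by the matching entry of indices, and the messages collected at each node are then combined with the chosen strategy. The toy version below uses plain Python lists and scalar messages instead of mlx arrays. The num_nodes argument is hypothetical; treating indices as target-node ids and filling nodes that receive no message with 0.0 are assumptions; 'softmax' is left out because its exact semantics are not described on this page.

```python
from typing import Callable, Dict, List

# Reducers for the non-softmax strategies accepted by `aggr`.
REDUCERS: Dict[str, Callable[[List[float]], float]] = {
    "add": sum,
    "max": max,
    "min": min,
    "mean": lambda xs: sum(xs) / len(xs),
}


def aggregate(messages: List[float], indices: List[int], num_nodes: int, aggr: str = "add") -> List[float]:
    """Toy scatter-reduce: combine scalar per-edge messages per target node."""
    buckets: List[List[float]] = [[] for _ in range(num_nodes)]
    for msg, idx in zip(messages, indices):
        buckets[idx].append(msg)  # group each message under its target node
    reduce = REDUCERS[aggr]
    # Assumption: a node that receives no message gets 0.0.
    return [reduce(b) if b else 0.0 for b in buckets]
```

With messages [1.0, 2.0, 4.0] sent to nodes [0, 0, 1] of a 3-node graph, 'add' yields [3.0, 4.0, 0.0] and 'mean' yields [1.5, 4.0, 0.0].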
- message(src_features: mlx.core.array, dst_features: mlx.core.array, **kwargs) → mlx.core.array
Computes messages between connected nodes.
By default, returns the features of the source nodes. Optional edge_weights can be passed directly in kwargs.
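- Parameters:
src_features (array) – Features of the source nodes.
dst_features (array) – Features of the destination nodes.
- Return type:
array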
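The default behaviour, and one way an optional edge_weights keyword could change it, can be sketched with plain Python lists. This is a toy stand-in, not the library's method: it assumes src_features and dst_features are already gathered per edge (one row per edge), and the weighting rule (scale each message by its edge weight) is a hypothetical example of what an override might do.

```python
from typing import List, Optional

Vector = List[float]


def message(
    src_features: List[Vector],
    dst_features: List[Vector],
    edge_weights: Optional[List[float]] = None,
) -> List[Vector]:
    """Toy per-edge message function; one row per edge."""
    # dst_features is unused here; an override could combine it with the source rows.
    if edge_weights is None:
        # Default described above: the message is the source node's features.
        return [list(x) for x in src_features]
    # Hypothetical weighted variant: scale each source row by its edge weight.
    return [[w * v for v in x] for x, w in zip(src_features, edge_weights)]
```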
- propagate(edge_index: mlx.core.array, node_features: mlx.core.array | Tuple[mlx.core.array, mlx.core.array], message_kwargs: Dict | None = {}, aggregate_kwargs: Dict | None = {}, update_kwargs: Dict | None = {}) → mlx.core.array
Computes messages from neighbors, aggregates them and updates the final node embeddings.
- Parameters:
edge_index (array) – Graph representation of shape [2, num_edges].
node_features (Union[array, Tuple[array, array]]) – Input node features/embeddings: either a single array, or a tuple of two arrays holding distinct source and destination node features.
message_kwargs (Optional[Dict]) – Arguments to pass to the message method.
aggregate_kwargs (Optional[Dict]) – Arguments to pass to the aggregate method.
update_kwargs (Optional[Dict]) – Arguments to pass to the update_nodes method.
- Return type:
array
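The three steps that propagate chains together can be traced end to end on a tiny graph. The sketch below is a pure-Python illustration, not the library code: it assumes row 0 of edge_index holds source nodes and row 1 holds destination nodes (the PyTorch Geometric convention), and it uses the default copy-the-source message, the 'add' strategy and an identity update.

```python
from typing import List, Sequence, Tuple, Union

Vector = List[float]
Features = Union[List[Vector], Tuple[List[Vector], List[Vector]]]


def propagate(edge_index: Tuple[Sequence[int], Sequence[int]], node_features: Features) -> List[Vector]:
    """Toy message -> aggregate -> update pass over plain Python lists."""
    # A tuple means distinct source and destination feature sets.
    if isinstance(node_features, tuple):
        src_nodes, dst_nodes = node_features
    else:
        src_nodes = dst_nodes = node_features
    senders, receivers = edge_index  # assumed: row 0 = sources, row 1 = destinations
    # 1. Message: copy each edge's source features (dst features would also be available here).
    messages = [list(src_nodes[s]) for s in senders]
    # 2. Aggregate: sum messages into their destination nodes ("add").
    dim = len(src_nodes[0])
    aggregated = [[0.0] * dim for _ in dst_nodes]
    for msg, r in zip(messages, receivers):
        for k, v in enumerate(msg):
            aggregated[r][k] += v
    # 3. Update: identity, so the new embeddings are the aggregated messages.
    return aggregated
```

For the cycle 0 → 1 → 2 → 0 with features [[1, 0], [0, 1], [2, 2]], each node receives its predecessor's features, giving [[2, 2], [1, 0], [0, 1]].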
- update_nodes(aggregated: mlx.core.array, **kwargs) → mlx.core.array
Updates the final embeddings given the aggregated messages.
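Because update_nodes only receives the aggregated messages, anything else it needs has to come in through update_kwargs in propagate. The toy override below assumes a hypothetical root_features keyword carrying each node's own features, adds it as a residual connection and then applies a ReLU; both choices are illustrative and are not documented defaults of the library.

```python
from typing import List, Optional

Vector = List[float]


def update_nodes(aggregated: List[Vector], root_features: Optional[List[Vector]] = None) -> List[Vector]:
    """Toy update step over plain Python lists."""
    # Hypothetical residual: add each node's own features when they are provided.
    if root_features is not None:
        aggregated = [[a + r for a, r in zip(agg, root)] for agg, root in zip(aggregated, root_features)]
    # Elementwise ReLU on the (possibly residual-augmented) embeddings.
    return [[max(0.0, v) for v in row] for row in aggregated]
```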