Source code for mlx_graphs.nn.message_passing

from typing import Any, Dict, Optional, Tuple, Union, get_args

import mlx.core as mx
import mlx.nn as nn

from mlx_graphs.utils import ScatterAggregations, get_src_dst_features, scatter


class MessagePassing(nn.Module):
    """Base class for creating Message Passing Neural Networks (MPNNs) [1].

    Inherit this class to build arbitrary GNN models based on the message
    passing paradigm. Subclasses implement ``__call__`` and typically invoke
    ``propagate``, which chains ``message`` -> ``aggregate`` -> ``update_nodes``.

    This implementation is inspired from PyTorch Geometric [2].

    Args:
        aggr: Aggregation strategy used to aggregate messages

    Raises:
        ValueError: If ``aggr`` is not one of the values allowed by
            ``ScatterAggregations``.

    References:
        [1] `Gilmer et al. Neural Message Passing for Quantum Chemistry.\
        <https://proceedings.mlr.press/v70/gilmer17a/gilmer17a.pdf>`_

        [2] `Fey et al. PyG <https://pytorch-geometric.readthedocs.io/en/latest/\
        generated/torch_geometric.nn.conv.MessagePassing.html>`_
    """

    def __init__(self, aggr: ScatterAggregations = "add"):
        super().__init__()

        allowed_aggregations = get_args(ScatterAggregations)
        if aggr not in allowed_aggregations:
            # Single message string so that str(exc) is readable instead of
            # being rendered as a tuple of two strings.
            raise ValueError(
                "Invalid aggregation function. "
                f"Available values are {allowed_aggregations}"
            )

        self.aggr: ScatterAggregations = aggr
        # Set by `propagate` on every call and read by `aggregate`.
        self.num_nodes = None

    def __call__(self, node_features: mx.array, edge_index: mx.array, **kwargs: Any):
        """Forward pass of the layer; must be overridden by subclasses.

        Args:
            node_features: Input node features/embeddings
            edge_index: Graph representation of shape `[2, num_edges]`
            **kwargs: Optional layer-specific arguments

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def propagate(
        self,
        edge_index: mx.array,
        node_features: Union[mx.array, Tuple[mx.array, mx.array]],
        message_kwargs: Optional[Dict] = None,
        aggregate_kwargs: Optional[Dict] = None,
        update_kwargs: Optional[Dict] = None,
    ) -> mx.array:
        """Computes messages from neighbors, aggregates them and updates
        the final node embeddings.

        Args:
            edge_index: Graph representation of shape `[2, num_edges]`
            node_features: Input node features/embeddings.
                Can be either an array or a tuple of arrays, for distinct src and
                dst node features.
            message_kwargs: Arguments to pass to the `message` method.
                ``None`` (the default) is treated as no extra arguments.
            aggregate_kwargs: Arguments to pass to the `aggregate` method.
                ``None`` (the default) is treated as no extra arguments.
            update_kwargs: Arguments to pass to the `update_nodes` method.
                ``None`` (the default) is treated as no extra arguments.

        Returns:
            The value returned by `update_nodes` for the aggregated messages.

        Raises:
            ValueError: If ``edge_index`` is not an ``mx.array`` whose first
                dimension is 2, or if ``node_features`` is neither an
                ``mx.array`` nor a tuple of exactly two ``mx.array``.
        """
        # Normalise the optional kwargs: avoids shared mutable `{}` defaults
        # and lets callers pass an explicit `None`, as the annotation allows.
        message_kwargs = {} if message_kwargs is None else message_kwargs
        aggregate_kwargs = {} if aggregate_kwargs is None else aggregate_kwargs
        update_kwargs = {} if update_kwargs is None else update_kwargs

        if not (isinstance(edge_index, mx.array) and edge_index.shape[0] == 2):
            raise ValueError("Edge index should be an array with shape (2, |E|)")

        if isinstance(node_features, tuple):
            if len(node_features) != 2 or not all(
                isinstance(array, mx.array) for array in node_features
            ):
                raise ValueError(
                    "Invalid shape for `node_features`, should be a tuple of 2 mx.array"
                )
        elif not isinstance(node_features, mx.array):
            raise ValueError(
                "Invalid shape for `node_features`, should be an `mx.array`, "
                f"found {type(node_features)}"
            )

        src_features, dst_features = get_src_dst_features(edge_index, node_features)

        # NOTE(review): for tuple inputs the node count comes from the first
        # array, while aggregation below indexes by `edge_index[1]`. Confirm
        # this is intended when the two arrays have different lengths.
        self.num_nodes = (
            node_features if isinstance(node_features, mx.array) else node_features[0]
        ).shape[0]
        dst_idx = edge_index[1]

        # shapes: (|E| -> |E|)
        messages = self.message(
            src_features=src_features,
            dst_features=dst_features,
            **message_kwargs,
        )

        # shapes: (|E| -> |N|)
        aggregated = self.aggregate(
            messages=messages,
            indices=dst_idx,
            **aggregate_kwargs,
        )

        # shapes: (|N| -> |N|)
        output = self.update_nodes(
            aggregated=aggregated,
            **update_kwargs,
        )

        return output

    def message(
        self, src_features: mx.array, dst_features: mx.array, **kwargs
    ) -> mx.array:
        """Computes messages between connected nodes.

        By default, returns the features of source nodes.
        Optional ``edge_weights`` can be directly integrated in ``kwargs``

        Args:
            src_features: Source node embeddings
            dst_features: Destination node embeddings (unused by this default
                implementation)
            edge_weights: Array of scalars with shape (num_edges,) or (num_edges, 1)
                used to weigh neighbor features during aggregation. Default: ``None``
            **kwargs: Optional args to compute messages

        Returns:
            ``src_features`` unchanged when no ``edge_weights`` is given,
            otherwise ``src_features`` multiplied by the weights reshaped to
            a column of shape ``(-1, 1)``.
        """
        edge_weights = kwargs.get("edge_weights", None)

        if edge_weights is None:
            return src_features

        # Reshape to a column so the product broadcasts across the feature axis.
        return edge_weights.reshape(-1, 1) * src_features

    def aggregate(self, messages: mx.array, indices: mx.array, **kwargs) -> mx.array:
        """Aggregates the messages using the `self.aggr` strategy.

        Delegates to ``scatter`` with ``self.num_nodes`` as the output size;
        ``self.num_nodes`` is set by ``propagate``.

        Args:
            messages: Computed messages
            indices: Indices representing the nodes that receive messages
            **kwargs: Optional args to aggregate messages

        Returns:
            The result of ``scatter`` applied to the messages.
        """
        return scatter(messages, indices, self.num_nodes, self.aggr)

    def update_nodes(self, aggregated: mx.array, **kwargs) -> mx.array:
        """Updates the final embeddings given the aggregated messages.

        The default implementation is the identity.

        Args:
            aggregated: aggregated messages
            **kwargs: optional args to update messages

        Returns:
            ``aggregated``, unchanged.
        """
        return aggregated