mlx_graphs.nn.MessagePassing

class mlx_graphs.nn.MessagePassing(aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add')

Bases: Module

Base class for creating Message Passing Neural Networks (MPNNs) [1].

Inherit from this class to build arbitrary GNN models based on the message passing paradigm. This implementation is inspired by PyTorch Geometric [2].
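As a rough guide to how the methods documented on this page fit together, one message passing step can be written as below. The notation is an assumption made here for illustration, not taken from the library: M stands for message, the circled plus for aggregate with the chosen aggr strategy, U for update_nodes, h_j and h_i for source and destination node features, and N(i) for the nodes with an edge pointing to node i. Optional keyword arguments such as edge weights are left out.

```latex
h_i' = U\Big( \bigoplus_{j \in \mathcal{N}(i)} M\big(h_j,\, h_i\big) \Big)
```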

Parameters:

aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – Strategy used to aggregate messages. Default: 'add'

References

[1] Gilmer et al. Neural Message Passing for Quantum Chemistry.

[2] Fey et al. PyTorch Geometric (PyG).

__call__(node_features: mlx.core.array, edge_index: mlx.core.array, **kwargs: Any)

Call self as a function.

Methods

aggregate(messages, indices, **kwargs)

Aggregates the messages using the self.aggr strategy.

message(src_features, dst_features, **kwargs)

Computes messages between connected nodes.

propagate(edge_index, node_features[, ...])

Computes messages from neighbors, aggregates them and updates the final node embeddings.

update_nodes(aggregated, **kwargs)

Updates the final embeddings given the aggregated messages.

aggregate(messages: mlx.core.array, indices: mlx.core.array, **kwargs) → mlx.core.array

Aggregates the messages using the self.aggr strategy.

Parameters:
  • messages (array) – Computed messages

  • indices (array) – Indices representing the nodes that receive messages

  • **kwargs – Optional args to aggregate messages

Return type:

array
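The idea of reducing per-edge messages into one vector per receiving node can be sketched in plain Python, without MLX. This is an illustration only, not the library's implementation: the name aggregate_sketch and the num_nodes parameter are hypothetical, nodes that receive no message are assumed to get zeros, and 'softmax' is left out because this page does not specify its output.

```python
def aggregate_sketch(messages, indices, num_nodes, aggr="add"):
    """Reduce per-edge messages into one vector per receiving node.

    Hypothetical helper for illustration; not part of mlx_graphs.
    """
    dim = len(messages[0])

    # Group each message under the node that receives it.
    buckets = [[] for _ in range(num_nodes)]
    for msg, dst in zip(messages, indices):
        buckets[dst].append(msg)

    # One reducer per strategy, applied feature-wise.
    reducers = {
        "add": sum,
        "mean": lambda col: sum(col) / len(col),
        "max": max,
        "min": min,
    }
    reduce_fn = reducers[aggr]

    out = []
    for bucket in buckets:
        if not bucket:
            # Assumption: a node with no incoming message gets zeros.
            out.append([0.0] * dim)
        else:
            # zip(*bucket) iterates over feature columns of the bucket.
            out.append([reduce_fn(col) for col in zip(*bucket)])
    return out


# Three messages: two sent to node 0, one sent to node 2.
messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
indices = [0, 0, 2]

summed = aggregate_sketch(messages, indices, num_nodes=3, aggr="add")
averaged = aggregate_sketch(messages, indices, num_nodes=3, aggr="mean")
largest = aggregate_sketch(messages, indices, num_nodes=3, aggr="max")
```

Here node 1 receives nothing, so its row stays at zero under the stated assumption; the library may treat that case differently.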

message(src_features: mlx.core.array, dst_features: mlx.core.array, **kwargs) → mlx.core.array

Computes messages between connected nodes.

By default, returns the features of the source nodes. Optional edge_weights can be passed directly through kwargs.

Parameters:
  • src_features (array) – Source node embeddings

  • dst_features (array) – Destination node embeddings

  • edge_weights – Array of scalars with shape (num_edges,) or (num_edges, 1) used to weigh neighbor features during aggregation. Default: None

  • **kwargs – Optional args to compute messages

Return type:

array
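The default behaviour described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: message_sketch is a hypothetical name, and multiplying each source feature row by its edge weight is one plausible reading of "used to weigh neighbor features".

```python
def message_sketch(src_features, dst_features, edge_weights=None):
    """Return source features, optionally scaled by one weight per edge.

    Hypothetical helper for illustration; not part of mlx_graphs.
    """
    if edge_weights is None:
        # Default: the message is simply the source node's features.
        return [list(row) for row in src_features]

    scaled = []
    for row, weight in zip(src_features, edge_weights):
        # Accept a bare scalar or a length-1 sequence, mirroring the
        # documented shapes (num_edges,) and (num_edges, 1).
        scale = weight[0] if isinstance(weight, (list, tuple)) else weight
        scaled.append([scale * value for value in row])
    return scaled


src = [[1.0, 2.0], [3.0, 4.0]]  # one row per edge
dst = [[0.0, 0.0], [0.0, 0.0]]  # unused by the default message

plain = message_sketch(src, dst)
weighted_flat = message_sketch(src, dst, edge_weights=[0.5, 2.0])
weighted_column = message_sketch(src, dst, edge_weights=[[0.5], [2.0]])
```

A subclass that needs the destination features, for example to compute attention scores, would override message and use both arguments.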

propagate(edge_index: mlx.core.array, node_features: mlx.core.array | Tuple[mlx.core.array, mlx.core.array], message_kwargs: Dict | None = {}, aggregate_kwargs: Dict | None = {}, update_kwargs: Dict | None = {}) → mlx.core.array

Computes messages from neighbors, aggregates them and updates the final node embeddings.

Parameters:
  • edge_index (array) – Graph connectivity, given as an array of shape [2, num_edges]

  • node_features (Union[array, Tuple[array, array]]) – Input node features/embeddings. Can be either an array or a tuple of arrays, for distinct src and dst node features.

  • message_kwargs (Optional[Dict]) – Arguments to pass to the message method

  • aggregate_kwargs (Optional[Dict]) – Arguments to pass to the aggregate method

  • update_kwargs (Optional[Dict]) – Arguments to pass to the update_nodes method

Return type:

array
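The three stages named above (compute messages, aggregate them, update the embeddings) can be sketched end to end in plain Python. This is a sketch under assumptions, not the library's implementation: every function name below is hypothetical, row 0 of edge_index is assumed to hold source node ids and row 1 destination node ids, and the keyword-argument dictionaries are omitted for brevity.

```python
def propagate_sketch(edge_index, node_features, message_fn, aggregate_fn, update_fn):
    """Gather features per edge, then run message -> aggregate -> update."""
    # Assumption: row 0 = source node ids, row 1 = destination node ids.
    src_idx, dst_idx = edge_index

    # A tuple supplies distinct source and destination feature sets.
    if isinstance(node_features, tuple):
        src_nodes, dst_nodes = node_features
    else:
        src_nodes = dst_nodes = node_features

    # Gather one feature row per edge for each endpoint.
    src_features = [src_nodes[i] for i in src_idx]
    dst_features = [dst_nodes[j] for j in dst_idx]

    messages = message_fn(src_features, dst_features)
    aggregated = aggregate_fn(messages, dst_idx, len(dst_nodes))
    return update_fn(aggregated)


def copy_source(src_features, dst_features):
    # Message step: forward the source features unchanged.
    return src_features


def sum_by_destination(messages, indices, num_nodes):
    # Aggregate step: add each message into its destination node's row.
    dim = len(messages[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for msg, dst in zip(messages, indices):
        for k, value in enumerate(msg):
            out[dst][k] += value
    return out


def identity(aggregated):
    # Update step: return the aggregated messages as the new embeddings.
    return aggregated


# Edges 0 -> 1, 0 -> 2 and 1 -> 2 on a three-node graph.
edge_index = [[0, 0, 1], [1, 2, 2]]
node_features = [[1.0], [10.0], [100.0]]

result = propagate_sketch(
    edge_index, node_features, copy_source, sum_by_destination, identity
)
# result == [[0.0], [1.0], [11.0]]
```

Node 2 receives messages from nodes 0 and 1, so its new value is 1.0 + 10.0; node 0 has no incoming edge and stays at zero in this sketch.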

update_nodes(aggregated: mlx.core.array, **kwargs) → mlx.core.array

Updates the final embeddings given the aggregated messages.

Parameters:
  • aggregated (array) – Aggregated messages

  • **kwargs – Optional args to update the node embeddings

Return type:

array