Source code for mlx_graphs.nn.conv.simple_conv

from typing import Literal, Optional, get_args

import mlx.core as mx

from mlx_graphs.nn.message_passing import MessagePassing
from mlx_graphs.utils import ScatterAggregations, add_self_loops

# Names of the supported strategies for combining a root node's own features
# with the aggregated messages of its neighbors. ``SimpleConv.__init__``
# validates its ``combine_root_func`` argument against these values.
CombineRootFunctions = Literal["sum", "cat", "self_loop"]

class SimpleConv(MessagePassing):
    """A simple non-trainable message passing layer.

    .. math::
        \\mathbf{x}^{\\prime}_i = \\bigoplus_{j \\in \\mathcal{N(i)}}
        e_{j,i} \\cdot \\mathbf{x}_j

    where :math:`\\bigoplus` denotes an aggregation strategy (e.g. ``'add'``,
    ``'mean'``), :math:`e_{j,i}` denotes the edge weight between the source
    node :math:`j` and the target node :math:`i`, :math:`\\mathcal{N(i)}`
    denotes the neighbors of node :math:`i`, and :math:`\\mathbf{x}_j` denotes
    the features of node :math:`j`.

    Inspired by the ``SimpleConv`` layer of PyTorch Geometric
    (``torch_geometric.nn.conv.SimpleConv``).

    Args:
        aggr: Aggregation strategy used to aggregate messages,
            e.g. ``'add'``, ``'mean'``, ``'max'``. Default: ``'add'``
        combine_root_func: Strategy used to combine the features from the
            root nodes. Available values: ``'sum'``, ``'cat'``,
            ``'self_loop'`` or ``None``.
            ``'sum'``: It sums up the neighborhood's message and the root
            node's features.
            ``'cat'``: It concatenates the neighborhood's message and the
            root node's features.
            ``'self_loop'``: It adds a self-loop for each root node and
            aggregates the messages. If the graph is weighted then the edge
            weights of self-loops will be set to :obj:`1`.
            Default: ``None``

    Raises:
        ValueError: If ``combine_root_func`` is neither ``None`` nor one of
            the values listed in ``CombineRootFunctions``.

    Example:

    .. code-block:: python

        import mlx.core as mx
        from mlx_graphs.nn import SimpleConv

        # Sum the messages from the neighbors.
        # Use a self-loop for each root node.
        conv = SimpleConv(aggr="add", combine_root_func="self_loop")

        node_features = mx.ones((5, 3))
        edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
        edge_weights = mx.array([10, 20, 5, 2, 15])

        h = conv(edge_index, node_features, edge_weights)
    """

    def __init__(
        self,
        aggr: ScatterAggregations = "add",
        combine_root_func: Optional[CombineRootFunctions] = None,
        **kwargs,
    ):
        super().__init__(aggr, **kwargs)

        valid_funcs = get_args(CombineRootFunctions)
        if combine_root_func is not None and combine_root_func not in valid_funcs:
            # Build ONE message string: passing several positional arguments
            # to ValueError makes str(exc) render as a tuple.
            raise ValueError(
                f"Invalid combine_root_func {combine_root_func!r}. "
                f"Available values are {valid_funcs}"
            )

        self.combine_root_func = combine_root_func

    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_weights: Optional[mx.array] = None,
    ) -> mx.array:
        """Computes the forward pass of SimpleConv.

        Args:
            edge_index: Input edge index of shape `[2, num_edges]`
            node_features: Input node features
            edge_weights: Edge weights leveraged in message passing.
                Default: ``None``

        Returns:
            The computed node embeddings
        """
        # Add self-loops, if needed.
        if self.combine_root_func == "self_loop":
            if edge_weights is None:
                edge_index = add_self_loops(edge_index)
            else:
                # The weights are handed to add_self_loops as a column and
                # flattened again afterwards.
                # NOTE(review): presumably add_self_loops expects 2-D edge
                # features -- confirm against its documentation.
                edge_index, edge_weights = add_self_loops(
                    edge_index, edge_weights.reshape(-1, 1)
                )
                edge_weights = edge_weights.reshape(-1)

        # Compute messages and aggregate them.
        output = self.propagate(
            edge_index=edge_index,
            node_features=node_features,
            message_kwargs={"edge_weights": edge_weights},
        )

        # Combine the root node features with the aggregated messages.
        if self.combine_root_func == "sum":
            return output + node_features
        if self.combine_root_func == "cat":
            # Root features come first, aggregated messages second.
            return mx.concatenate([node_features, output], axis=-1)

        # ``None`` and ``'self_loop'`` need no post-aggregation combination.
        return output