mlx_graphs.nn.GeneralizedRelationalConv#

class mlx_graphs.nn.GeneralizedRelationalConv(in_features_dim: int, out_features_dim: int, num_relations: int, message_func: Literal['transe', 'distmult', 'rotate'] = 'distmult', aggregate_func: Literal['add', 'mean', 'pna'] = 'add', layer_norm: bool = True, activation: str = 'relu', dependent: bool = False, node_dim: int = 0, **kwargs)[source]#

Bases: MessagePassing

Generalized relational convolution layer from the paper Neural Bellman-Ford Networks: A General Graph Neural Network Framework for Link Prediction.

Adapted from the PyG implementation.

Part of Neural Bellman-Ford Networks (NBFNet), which hold state-of-the-art results in knowledge graph (KG) completion. Works with multi-relational graphs where edge types are stored in edge_labels. The message function composes node and relation vectors in one of three ways. The expected usage is with “labeling trick” graphs, where one node in each graph is labeled with a query vector while the rest are zeros; message passing is then performed separately for each data point in the batch. The expected input shape is [batch_size, num_nodes, input_dim].

Alternatively, the layer can work as a standard relational convolution with shape [num_nodes, input_dim].
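For instance, the boundary tensor for the labeling trick can be built as follows (a minimal sketch; head_index is a hypothetical array of labeled node indices, not part of this API):

import mlx.core as mx

batch_size, num_nodes, dim = 2, 5, 16
query = mx.random.uniform(0, 1, shape=(batch_size, dim))
head_index = mx.array([0, 3])  # hypothetical: the labeled (head) node of each query

# only the head node of each graph carries the query vector; all other rows are zeros
one_hot = (mx.arange(num_nodes)[None, :] == head_index[:, None])[..., None]
boundary = one_hot.astype(query.dtype) * query[:, None, :]  # [batch_size, num_nodes, dim]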

Note that this implementation materializes all edge messages and is O(E). The complexity could be further reduced by adapting the O(V) rspmm C++ kernel from the NBFNet-PyG repo to the MLX framework (not implemented here).

Parameters:
  • in_features_dim (int) – input feature dimension (same for node and edge features)

  • out_features_dim (int) – output node feature dimension

  • num_relations (int) – number of unique relations in the graph

  • message_func (Literal['transe', 'distmult', 'rotate']) – how node and relation vectors are composed: “transe” (sum), “distmult” (element-wise product), or “rotate” (complex rotation); see the sketch after this parameter list. Default: distmult

  • aggregate_func (Literal['add', 'mean', 'pna']) – “add”, “mean”, or “pna” (principal neighborhood aggregation). Default: add

  • layer_norm (bool) – whether to apply layer normalization (often crucial for performance). Default: True

  • activation (str) – non-linearity. Default: relu

  • dependent (bool) – if True, relation embeddings are built from the input query; if False, a separate relation embedding matrix is used. Default: False

  • node_dim (int) – for 3D batches, specifies which dimension contains the nodes. Default: 0
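The three composition operators can be sketched as follows (illustrative only, not the layer’s internal code):

import mlx.core as mx

def compose(x: mx.array, r: mx.array, message_func: str) -> mx.array:
    if message_func == "transe":      # translation: x + r
        return x + r
    if message_func == "distmult":    # element-wise product: x * r
        return x * r
    if message_func == "rotate":      # rotation in the complex plane
        # treat the two halves of the feature dim as real/imaginary parts
        x_re, x_im = mx.split(x, 2, axis=-1)
        r_re, r_im = mx.split(r, 2, axis=-1)
        return mx.concatenate(
            [x_re * r_re - x_im * r_im, x_re * r_im + x_im * r_re], axis=-1
        )
    raise ValueError(f"Unknown message_func: {message_func}")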

Example:

import mlx.core as mx
from mlx_graphs.nn import GeneralizedRelationalConv

input_dim = 16
output_dim = 16
num_relations = 3

conv = GeneralizedRelationalConv(input_dim, output_dim, num_relations)

batch_size = 2
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_types = mx.array([0, 0, 1, 1, 2])
boundary = mx.random.uniform(0, 1, shape=(batch_size, 5, input_dim))
size = (boundary.shape[1], boundary.shape[1])

layer_input = boundary  # initial node features which will be updated
h = conv(edge_index, layer_input, edge_types, boundary, size=size)

# optional: residual connection if input dim == output dim
h = h + layer_input
layer_input = h

# another conv type where relations are obtained from the additional
# query tensor
query = mx.random.uniform(0, 1, shape=(batch_size, input_dim))
conv2 = GeneralizedRelationalConv(
    input_dim, output_dim, num_relations, dependent=True)

h = conv2(edge_index, layer_input, edge_types, boundary, query, size=size)
__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_type: mlx.core.array, boundary: mlx.core.array, query: mlx.core.array | None = None, size: tuple[int, int] | None = None, edge_weights: mlx.core.array | None = None, **kwargs) mlx.core.array[source]#

Computes the forward pass of GeneralizedRelationalConv.

Parameters:
  • edge_index (array) – Input edge index of shape [2, num_edges]

  • node_features (array) – Input node features, shape [bs, num_nodes, dim] or [num_nodes, dim]

  • edge_type (array) – Input edge types of shape [num_edges,]

  • boundary (array) – Initial node features of shape [bs, num_nodes, dim] or [num_nodes, dim]

  • query (Optional[array]) – Optional input node queries, shape [bs, dim]

  • size (Optional[tuple[int, int]]) – A tuple encoding the size of the graph, e.g. (5, 5)

  • edge_weights (Optional[array]) – Edge weights leveraged in message passing. Default: None

Return type:

array

Returns:

The computed node embeddings
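For the non-batched case, the same call works with 2D inputs (a brief sketch reusing conv, edge_index, and edge_types from the example above, with the initial features doubling as the boundary condition):

x = mx.random.uniform(0, 1, shape=(5, input_dim))
h = conv(edge_index, x, edge_types, x, size=(5, 5))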

Methods

  • aggregate(messages, indices, edge_weights, ...) – Aggregates the messages using the self.aggr strategy.

  • message(src_features, dst_features, ...) – Computes messages between connected nodes.

  • update_nodes(aggregated, old) – Updates the final embeddings given the aggregated messages.

aggregate(messages: mlx.core.array, indices: mlx.core.array, edge_weights: mlx.core.array, dim_size: tuple[int, int]) mlx.core.array[source]#

Aggregates the messages using the self.aggr strategy.

Parameters:
  • messages (array) – Computed messages

  • indices (array) – Indices representing the nodes that receive messages

  • edge_weights (array) – Optional edge weights applied to the messages

  • dim_size (tuple[int, int]) – Size of the graph, e.g. (num_nodes, num_nodes), used to shape the aggregated output

Return type:

array
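With the “add” strategy, aggregation amounts to a scatter-add over the receiving nodes. A conceptual sketch (not the internal implementation):

import mlx.core as mx

num_nodes, dim = 5, 16
messages = mx.random.uniform(0, 1, shape=(8, dim))  # one message per edge
indices = mx.array([0, 0, 1, 1, 2, 3, 4, 4])        # receiving node of each message

out = mx.zeros((num_nodes, dim))
out = out.at[indices].add(messages)  # sums all messages arriving at each node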

message(src_features: mlx.core.array, dst_features: mlx.core.array, relation: mlx.core.array, boundary: mlx.core.array, edge_type: mlx.core.array) mlx.core.array[source]#

Computes messages between connected nodes.

Messages are built by composing the source node features with the relation embedding selected by each edge’s type, using the configured message_func.

Parameters:
  • src_features (array) – Source node embeddings

  • dst_features (array) – Destination node embeddings

  • relation (array) – Relation embeddings

  • boundary (array) – Initial node features used as the boundary condition

  • edge_type (array) – Edge types of shape [num_edges,], used to select relation embeddings

Return type:

array
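Conceptually, the message step looks as follows (a rough sketch reusing the compose helper from above; appending the boundary states as extra messages follows the NBFNet design and is an assumption about this port):

# shapes: src_features [num_edges, dim], relation [num_relations, dim],
# boundary [num_nodes, dim], edge_type [num_edges]
rel = relation[edge_type]                      # one relation vector per edge
msg = compose(src_features, rel, "distmult")   # per-edge composed messages
msg = mx.concatenate([msg, boundary], axis=0)  # assumed: boundary appended as extra messages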

update_nodes(aggregated: mlx.core.array, old: mlx.core.array) mlx.core.array[source]#

Updates the final embeddings given the aggregated messages.

Parameters:
  • aggregated (array) – Aggregated messages

  • old (array) – Node embeddings from before the aggregation step

Return type:

array
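As a rough sketch of the update step (assuming, following the NBFNet design, that the previous node states and the aggregated messages are combined through a linear projection, optional layer norm, and the activation; linear and layer_norm below are hypothetical stand-ins for the layer’s own modules):

import mlx.core as mx
import mlx.nn as nn

num_nodes, dim, out_dim = 5, 16, 16
linear = nn.Linear(2 * dim, out_dim)  # hypothetical stand-in for the layer's projection
layer_norm = nn.LayerNorm(out_dim)    # hypothetical stand-in
old = mx.random.uniform(0, 1, shape=(num_nodes, dim))
aggregated = mx.random.uniform(0, 1, shape=(num_nodes, dim))

combined = mx.concatenate([old, aggregated], axis=-1)  # [num_nodes, 2 * dim]
out = nn.relu(layer_norm(linear(combined)))            # [num_nodes, out_dim]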