Source code for mlx_graphs.nn.conv.rel_conv

from typing import Literal, Optional, get_args

import mlx.core as mx
from mlx import nn

from mlx_graphs.nn.linear import Linear
from mlx_graphs.nn.message_passing import MessagePassing
from mlx_graphs.utils import degree, scatter

MessageFunctions = Literal["transe", "distmult", "rotate"]
AggregationFunctions = Literal["add", "mean", "pna"]


class GeneralizedRelationalConv(MessagePassing):
    """Generalized relational convolution layer from `Neural Bellman-Ford Networks: A General Graph Neural Network Framework \
    for Link Prediction <https://arxiv.org/abs/2106.06935>`_ paper. Adapted from the PyG version `here \
    <https://github.com/KiddoZhu/NBFNet-PyG/blob/master/nbfnet/layers.py>`_.

    Part of Neural Bellman-Ford networks (NBFNet), a state-of-the-art architecture for
    knowledge graph (KG) completion. Works with multi-relational graphs where edge types
    are stored in `edge_labels`. The message function composes node and relation vectors
    in three possible ways.

    The expected behavior is to work with "labeling trick" graphs where one node in the
    graph is labeled with a `query` vector, while the rest are zeros. Message passing is
    then done separately for each data point in the batch. The input shape is expected to
    be [batch_size, num_nodes, input_dim]. Alternatively, the layer can work as a
    standard relational conv with shapes [num_nodes, input_dim].

    Note that this implementation materializes all edge messages and is O(E). The
    complexity can be further reduced by adapting the O(V) `rspmm` C++ kernel from the
    NBFNet-PyG repo to the MLX framework (not implemented here).

    Args:
        in_features_dim: input feature dimension (same for node and edge features)
        out_features_dim: output node feature dimension
        num_relations: number of unique relations in the graph
        message_func: "transe" (sum), "distmult" (mult), "rotate" (complex rotation).
            Default: ``distmult``
        aggregate_func: "add", "mean", or "pna". Default: ``add``
        layer_norm: whether to use layer norm (often crucial to the performance).
            Default: ``True``
        activation: non-linearity. Default: ``relu``
        dependent: whether to use a separate relation embedding matrix (``False``)
            or to build relations from the input query relation (``True``).
            Default: ``False``
        node_dim: for 3D batches, specifies which dimension contains all nodes.
            Default: ``0``

    Example:

    .. code-block:: python

        import mlx.core as mx
        from mlx_graphs.nn import GeneralizedRelationalConv

        input_dim = 16
        output_dim = 16
        num_relations = 3
        conv = GeneralizedRelationalConv(input_dim, output_dim, num_relations)

        batch_size = 2
        edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
        edge_types = mx.array([0, 0, 1, 1, 2])
        boundary = mx.random.uniform(0, 1, shape=(batch_size, 5, 16))
        size = (boundary.shape[1], boundary.shape[1])

        layer_input = boundary  # initial node features which will be updated

        h = conv(edge_index, layer_input, edge_types, boundary, size=size)

        # optional: residual connection if input dim == output dim
        h = h + layer_input
        layer_input = h

        # another conv type where relations are obtained from the additional
        # query tensor
        query = mx.random.uniform(0, 1, shape=(batch_size, 16))
        conv2 = GeneralizedRelationalConv(
            input_dim, output_dim, num_relations, dependent=True)
        h = conv2(edge_index, layer_input, edge_types, boundary, query, size=size)
    """

    eps = 1e-6

    def __init__(
        self,
        in_features_dim: int,
        out_features_dim: int,
        num_relations: int,
        message_func: MessageFunctions = "distmult",
        aggregate_func: AggregationFunctions = "add",
        layer_norm: bool = True,
        activation: str = "relu",
        dependent: bool = False,
        node_dim: int = 0,
        **kwargs,
    ):
        super(GeneralizedRelationalConv, self).__init__(**kwargs)

        if aggregate_func not in get_args(AggregationFunctions):
            raise ValueError(
                "Invalid aggregate_func.",
                f"Available values are {get_args(AggregationFunctions)}",
            )
        if message_func not in get_args(MessageFunctions):
            raise ValueError(
                "Invalid message_func.",
                f"Available values are {get_args(MessageFunctions)}",
            )

        self.in_features_dim = in_features_dim
        self.out_features_dim = out_features_dim
        self.num_relations = num_relations
        self.message_func = message_func
        self.aggregate_func = aggregate_func
        self.dependent = dependent
        self.node_dim = node_dim

        if layer_norm:
            self.layer_norm = nn.LayerNorm(out_features_dim)
        else:
            self.layer_norm = None
        if isinstance(activation, str):
            self.activation = getattr(nn, activation)
        else:
            self.activation = activation

        if self.aggregate_func == "pna":
            # 12 for 4 aggregations (mean, max, min, std)
            # and 3 scalers (identity, degree, 1/degree)
            # +1 for the old state, so 13 is the final multiplier
            self.linear = Linear(in_features_dim * 13, out_features_dim)
        else:
            self.linear = Linear(in_features_dim * 2, out_features_dim)

        if dependent:
            # obtain relation embeddings as a projection of the query relation
            self.relation_linear = Linear(
                in_features_dim, num_relations * in_features_dim
            )
        else:
            # relation embeddings as an independent embedding matrix per each layer
            self.relation = nn.Embedding(num_relations, in_features_dim)
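    # Note on the two relation parameterizations (illustrative summary derived
    # from the code above and the forward pass below, not part of the original
    # source):
    #   dependent=False: relations come from an nn.Embedding of shape
    #       (num_relations, in_features_dim), shared by all samples in a batch.
    #   dependent=True: relations are projected from each sample's query, i.e.
    #       (batch_size, in_features_dim) -> relation_linear -> reshape to
    #       (batch_size, num_relations, in_features_dim).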
    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_type: mx.array,
        boundary: mx.array,
        query: Optional[mx.array] = None,
        size: Optional[tuple[int, int]] = None,
        edge_weights: Optional[mx.array] = None,
        **kwargs,
    ) -> mx.array:
        """Computes the forward pass of GeneralizedRelationalConv.

        Args:
            edge_index: Input edge index of shape `[2, num_edges]`
            node_features: Input node features, shape `[bs, num_nodes, dim]`
                or `[num_nodes, dim]`
            edge_type: Input edge types of shape `[num_edges,]`
            boundary: Initial node feats `[bs, num_nodes, dim]` or `[num_nodes, dim]`
            query: Optional input node queries, shape `[bs, dim]`
            size: a tuple encoding the size of the graph, e.g. `(5, 5)`
            edge_weights: Edge weights leveraged in message passing. Default: ``None``

        Returns:
            The computed node embeddings
        """
        self.input_dims = len(node_features.shape)
        batch_size = node_features.shape[0] if self.input_dims == 3 else 1
        if size is None:
            num_nodes = (
                node_features.shape[0]
                if self.input_dims == 2
                else node_features.shape[1]
            )
            size = (num_nodes, num_nodes)

        # input: (bs, num_nodes, dim)
        if self.dependent:
            assert query is not None, "query must be supplied when dependent=True"
            assert (
                self.input_dims == 3
            ), "expected input shape is [batch_size, num_nodes, dim]"
            # relation features as a projection of input "query" (relation) embeddings
            relation = self.relation_linear(query).reshape(
                batch_size, self.num_relations, self.in_features_dim
            )
        else:
            # relation features as an embedding matrix unique to each layer
            # relation: (batch_size, num_relation, dim)
            relation = mx.repeat(self.relation.weight[None, :], batch_size, axis=0)

        if edge_weights is None:
            edge_weights = mx.ones(len(edge_type))

        # since mlx_graphs always gathers along dimension 0 (num_nodes are rows)
        # we have to reshape input features accordingly
        if self.input_dims == 3:
            node_features = node_features.transpose(1, 0, 2)
            boundary = boundary.transpose(1, 0, 2)

        # note that we send the initial boundary condition (node states at layer 0)
        # to the message passing
        # corresponds to Eq. 6 on p.5 in https://arxiv.org/pdf/2106.06935.pdf
        output = self.propagate(
            node_features=node_features,
            edge_index=edge_index,
            message_kwargs=dict(
                relation=relation, boundary=boundary, edge_type=edge_type
            ),
            aggregate_kwargs=dict(edge_weights=edge_weights, dim_size=size),
            update_kwargs=dict(old=node_features),
        )

        return output
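    # Shape walk-through for the batched (3D) case, as an illustration derived
    # from the forward pass above (not part of the original source):
    #   node_features, boundary: (bs, num_nodes, dim) -> transposed to
    #       (num_nodes, bs, dim) so that gathering happens along axis 0
    #   messages: (num_edges + num_nodes, bs, dim) once the boundary is appended
    #   propagate output: (num_nodes, bs, dim), transposed back to
    #       (bs, num_nodes, dim) at the end of update_nodes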
    def message(
        self,
        src_features: mx.array,
        dst_features: mx.array,
        relation: mx.array,
        boundary: mx.array,
        edge_type: mx.array,
    ) -> mx.array:
        # extracting relation features
        relation_j = relation[:, edge_type]
        if self.input_dims == 3:
            relation_j = relation_j.transpose(1, 0, 2)
        else:
            relation_j = relation_j.squeeze(0)

        if self.message_func == "transe":
            message = src_features + relation_j
        elif self.message_func == "distmult":
            message = src_features * relation_j
        elif self.message_func == "rotate":
            x_j_re, x_j_im = src_features.split(2, axis=-1)
            r_j_re, r_j_im = relation_j.split(2, axis=-1)
            message_re = x_j_re * r_j_re - x_j_im * r_j_im
            message_im = x_j_re * r_j_im + x_j_im * r_j_re
            message = mx.concatenate([message_re, message_im], axis=-1)
        else:
            raise ValueError("Unknown message function `%s`" % self.message_func)

        # augment messages with the boundary condition
        message = mx.concatenate(
            [message, boundary], axis=0
        )  # (num_edges + num_nodes, batch_size, input_dim)

        return message
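    # The three message functions above written out (illustration, not part of
    # the original source). With source node state x and relation vector r:
    #   transe:   m = x + r
    #   distmult: m = x * r            (element-wise product)
    #   rotate:   m = x rotated by r in complex form; splitting the features
    #             into real and imaginary halves,
    #             (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
    #             which is exactly message_re / message_im above.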
    def aggregate(
        self,
        messages: mx.array,
        indices: mx.array,
        edge_weights: mx.array,
        dim_size: tuple[int, int],
    ) -> mx.array:
        # augment aggregation index with self-loops for the boundary condition
        index = mx.concatenate(
            [indices, mx.arange(dim_size[0])]
        )  # (num_edges + num_nodes,)
        edge_weights = mx.concatenate([edge_weights, mx.ones(dim_size[0])])
        shape = [1] * messages.ndim
        shape[self.node_dim] = -1
        edge_weights = edge_weights.reshape(shape)

        if self.aggregate_func == "pna":
            mean = scatter(
                messages * edge_weights,
                index,
                axis=self.node_dim,
                out_size=dim_size[0],
                aggr="mean",
            )
            sq_mean = scatter(
                messages**2 * edge_weights,
                index,
                axis=self.node_dim,
                out_size=dim_size[0],
                aggr="mean",
            )
            max = scatter(
                messages * edge_weights,
                index,
                axis=self.node_dim,
                out_size=dim_size[0],
                aggr="max",
            )
            min = scatter(
                messages * edge_weights,
                index,
                axis=self.node_dim,
                out_size=dim_size[0],
                aggr="min",
            )
            std = mx.clip(sq_mean - mean**2, a_min=self.eps, a_max=None).sqrt()
            features = mx.concatenate(
                [mean[..., None], max[..., None], min[..., None], std[..., None]],
                axis=-1,
            )
            features = features.flatten(-2)
            if self.input_dims == 2:
                features = features[:, None, :]
            degree_out = degree(index, dim_size[0])[..., None, None]
            scale = degree_out.log()
            scale = scale / scale.mean()
            scales = mx.concatenate(
                [
                    mx.ones_like(scale),
                    scale,
                    1 / mx.clip(scale, a_min=1e-2, a_max=None),
                ],
                axis=-1,
            )
            output = (features[..., None] * scales[:, :, None, :]).flatten(-2)
            if self.input_dims == 2:
                output = output.squeeze(1)
        else:
            output = scatter(
                messages * edge_weights,
                index,
                axis=self.node_dim,
                out_size=dim_size[0],
                aggr=self.aggregate_func,  # type: ignore - it's either "add" or "mean"
            )

        return output
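    # Width check for the "pna" branch (illustration, not part of the original
    # source), using in_features_dim = 16 as an example:
    #   [mean, max, min, std]           -> 4 * 16 = 64 features per node
    #   x [identity, scale, 1/scale]    -> 3 * 64 = 192 features per node
    #   update_nodes then concatenates the old state (16), giving 208 = 13 * 16,
    #   matching Linear(in_features_dim * 13, out_features_dim) in __init__.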
    def update_nodes(
        self,
        aggregated: mx.array,
        old: mx.array,
    ) -> mx.array:
        # node update: a function of old states (old) and the layer's output (aggregated)
        output = self.linear(mx.concatenate([old, aggregated], axis=-1))
        if self.layer_norm:
            output = self.layer_norm(output)
        if self.activation:
            output = self.activation(output)
        return output.transpose(1, 0, 2) if self.input_dims == 3 else output
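# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module): building
# the "labeling trick" boundary condition described in the class docstring,
# where one head node per sample is labeled with its query vector and all
# other nodes start from zeros. Variable names below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_nodes, dim, batch_size = 5, 16, 2
    heads = mx.array([0, 3])  # the labeled (head) node of each sample
    query = mx.random.uniform(0, 1, shape=(batch_size, dim))

    # indicator of the head node per sample: (num_nodes, batch_size)
    one_hot = (mx.arange(num_nodes)[:, None] == heads[None, :]).astype(mx.float32)
    # boundary: (batch_size, num_nodes, dim), zeros except at the head nodes
    boundary = (one_hot[..., None] * query[None, :, :]).transpose(1, 0, 2)

    edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
    edge_types = mx.array([0, 0, 1, 1, 2])
    conv = GeneralizedRelationalConv(dim, dim, num_relations=3)
    h = conv(edge_index, boundary, edge_types, boundary, size=(num_nodes, num_nodes))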