mlx_graphs.nn.GeneralizedRelationalConv
- class mlx_graphs.nn.GeneralizedRelationalConv(in_features_dim: int, out_features_dim: int, num_relations: int, message_func: Literal['transe', 'distmult', 'rotate'] = 'distmult', aggregate_func: Literal['add', 'mean', 'pna'] = 'add', layer_norm: bool = True, activation: str = 'relu', dependent: bool = False, node_dim: int = 0, **kwargs)
Bases: MessagePassing
Generalized relational convolution layer from the paper Neural Bellman-Ford Networks: A General Graph Neural Network Framework for Link Prediction.
Adapted from the PyG implementation available here.
Part of Neural Bellman-Ford Networks (NBFNet), which hold state-of-the-art results in knowledge graph (KG) completion. The layer works with multi-relational graphs where edge types are stored in edge_labels. The message function composes node and relation vectors in one of three ways. The layer is designed for “labeling trick” graphs, where one node in the graph is labeled with a query vector while the rest are initialized to zeros. Message passing is then done separately for each data point in the batch, so the input shape is expected to be [batch_size, num_nodes, input_dim].
Alternatively, the layer can work as a standard relational convolution on unbatched inputs of shape [num_nodes, input_dim].
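The labeling-trick input described above can be sketched in a few lines. This is a minimal NumPy sketch, not the library's code: the helper name make_boundary and the choice of head nodes are hypothetical, and NumPy arrays stand in for mx arrays. Each batch element places its query vector on its own head node and leaves every other node at zero, producing the [batch_size, num_nodes, input_dim] tensor the layer expects.

```python
import numpy as np

def make_boundary(head_nodes, queries, num_nodes):
    """Hypothetical helper: labeling-trick initial features.

    head_nodes: (batch_size,) index of the labeled node per data point
    queries:    (batch_size, dim) query vector per data point
    """
    batch_size, dim = queries.shape
    boundary = np.zeros((batch_size, num_nodes, dim), dtype=queries.dtype)
    # exactly one (batch, node) position per data point receives its query
    boundary[np.arange(batch_size), head_nodes] = queries
    return boundary

rng = np.random.default_rng(0)
queries = rng.uniform(size=(2, 16)).astype(np.float32)
# data point 0 is labeled at node 0, data point 1 at node 3
boundary = make_boundary(np.array([0, 3]), queries, num_nodes=5)
print(boundary.shape)  # (2, 5, 16)
```

Because each data point carries its own labeled node, the same graph yields a different initial state per batch element, which is why message passing runs separately for each one.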
Note that this implementation materializes a message for every edge, so its memory cost is O(E). This could be reduced to O(V) by porting the rspmm C++ kernel from the NBFNet-PyG repository to MLX (not implemented here).
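To see where the O(E) cost comes from, here is a NumPy sketch (an illustration under assumptions, not the layer's actual code) of gather-then-scatter message passing with a DistMult-style composition. The gather step builds one message vector per edge, so the intermediate tensor has shape (batch_size, num_edges, dim) before it is summed into the destination nodes. The graph reuses the edge list from the class example.

```python
import numpy as np

batch_size, num_nodes, dim, num_relations = 2, 5, 16, 3
edge_index = np.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_type = np.array([0, 0, 1, 1, 2])

rng = np.random.default_rng(0)
x = rng.uniform(size=(batch_size, num_nodes, dim)).astype(np.float32)
relation = rng.uniform(size=(num_relations, dim)).astype(np.float32)

src, dst = edge_index
# Gather: one message per edge -> shape (batch_size, num_edges, dim),
# so memory grows with the number of edges E rather than nodes V.
messages = x[:, src] * relation[edge_type]

# Scatter-add each edge message into its destination node.
out = np.zeros_like(x)
np.add.at(out, (slice(None), dst), messages)

print(messages.shape, messages.nbytes)  # (2, 5, 16) 640
```

A fused sparse kernel such as rspmm avoids allocating `messages` at all by accumulating directly into `out`, which is what brings memory down to O(V).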
- Parameters:
  - in_features_dim (int) – input feature dimension (same for node and edge features)
  - out_features_dim (int) – output node feature dimension
  - num_relations (int) – number of unique relations in the graph
  - message_func (Literal['transe', 'distmult', 'rotate']) – “transe” (sum), “distmult” (multiplication), or “rotate” (complex rotation). Default: distmult
  - aggregate_func (Literal['add', 'mean', 'pna']) – “add”, “mean”, or “pna”. Default: add
  - layer_norm (bool) – whether to use layer normalization (often crucial for performance). Default: True
  - activation (str) – non-linearity. Default: relu
  - dependent (bool) – whether to use a separate relation embedding matrix (False) or to build relation embeddings from the input query (True). Default: False
  - node_dim (int) – for 3D batches, specifies which dimension contains all nodes. Default: 0
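The two settings of dependent can be illustrated with shapes alone. This NumPy sketch assumes the MLX port follows the NBFNet-PyG reference it was adapted from: with dependent=False each layer owns a relation table shared across the batch, and with dependent=True the relation vectors come from a linear projection of the per-example query. The arrays W and b are hypothetical stand-ins for a learned linear layer.

```python
import numpy as np

batch_size, dim, num_relations = 2, 16, 3
rng = np.random.default_rng(0)

# dependent=False (assumed): one learnable table per layer, shared by the batch.
relation_table = rng.normal(size=(num_relations, dim)).astype(np.float32)
rel_independent = np.broadcast_to(relation_table, (batch_size, num_relations, dim))

# dependent=True (assumed): project each query into num_relations vectors,
# so every batch element gets its own relation set.
query = rng.uniform(size=(batch_size, dim)).astype(np.float32)
W = rng.normal(size=(dim, num_relations * dim)).astype(np.float32)  # hypothetical weights
b = np.zeros(num_relations * dim, dtype=np.float32)                  # hypothetical bias
rel_dependent = (query @ W + b).reshape(batch_size, num_relations, dim)
```

Either way the result has shape (batch_size, num_relations, dim); what differs is whether relation vectors are conditioned on the query.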
Example:
```python
import mlx.core as mx
from mlx_graphs.nn import GeneralizedRelationalConv

input_dim = 16
output_dim = 16
num_relations = 3
conv = GeneralizedRelationalConv(input_dim, output_dim, num_relations)

batch_size = 2
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
edge_types = mx.array([0, 0, 1, 1, 2])
boundary = mx.random.uniform(0, 1, shape=(batch_size, 5, 16))
size = (boundary.shape[1], boundary.shape[1])

layer_input = boundary  # initial node features which will be updated
h = conv(edge_index, layer_input, edge_types, boundary, size=size)

# optional: residual connection if input dim == output dim
h = h + layer_input
layer_input = h

# another conv type where relations are obtained from the additional
# query tensor
query = mx.random.uniform(0, 1, shape=(batch_size, 16))
conv2 = GeneralizedRelationalConv(
    input_dim, output_dim, num_relations, dependent=True)
h = conv2(edge_index, layer_input, edge_types, boundary, query, size=size)
```
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_type: mlx.core.array, boundary: mlx.core.array, query: mlx.core.array | None = None, size: tuple[int, int] | None = None, edge_weights: mlx.core.array | None = None, **kwargs) → mlx.core.array
Computes the forward pass of GeneralizedRelationalConv.
- Parameters:
  - edge_index (array) – Input edge index of shape [2, num_edges]
  - node_features (array) – Input node features of shape [bs, num_nodes, dim] or [num_nodes, dim]
  - edge_type (array) – Input edge types of shape [num_edges]
  - boundary (array) – Initial node features of shape [bs, num_nodes, dim] or [num_nodes, dim]
  - query (Optional[array]) – Optional input queries of shape [bs, dim]
  - size (Optional[tuple[int, int]]) – a tuple encoding the size of the graph, e.g. (5, 5)
  - edge_weights (Optional[array]) – Edge weights used in message passing. Default: None
- Returns:
  The computed node embeddings
- Return type:
  mlx.core.array
Methods
- aggregate(messages, indices, edge_weights, ...) – Aggregates the messages using the self.aggr strategy.
- message(src_features, dst_features, ...) – Computes messages between connected nodes.
- update_nodes(aggregated, old) – Updates the final embeddings given the aggregated messages.
- aggregate(messages: mlx.core.array, indices: mlx.core.array, edge_weights: mlx.core.array, dim_size: tuple[int, int]) → mlx.core.array
Aggregates the messages using the self.aggr strategy.
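The "add" and "mean" strategies can be sketched as a weighted scatter. This is a NumPy sketch under an assumption: following the NBFNet-PyG reference, the boundary is treated as one extra self-loop message per node with weight 1, so every node keeps its initial condition. The function name and argument order are illustrative and do not match the method signature above; "pna" (which combines several statistics with degree scalers) is not shown.

```python
import numpy as np

def aggregate(messages, dst, edge_weights, boundary, reduce="add"):
    """Illustrative weighted scatter aggregation (not the library method).

    messages: (bs, E, d), dst: (E,), edge_weights: (E,), boundary: (bs, V, d)
    """
    bs, num_nodes, dim = boundary.shape
    # Assumed NBFNet-style augmentation: boundary = self-loop message, weight 1.
    msgs = np.concatenate([messages, boundary], axis=1)
    index = np.concatenate([dst, np.arange(num_nodes)])
    weights = np.concatenate([edge_weights, np.ones(num_nodes, dtype=messages.dtype)])

    out = np.zeros_like(boundary)
    np.add.at(out, (slice(None), index), msgs * weights[None, :, None])
    if reduce == "mean":
        # incoming-message count per node (>= 1 thanks to the self-loop)
        counts = np.bincount(index, minlength=num_nodes).astype(out.dtype)
        out = out / counts[None, :, None]
    return out
```

With the self-loop, "mean" never divides by zero, even for nodes without incoming edges.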
- message(src_features: mlx.core.array, dst_features: mlx.core.array, relation: mlx.core.array, boundary: mlx.core.array, edge_type: mlx.core.array) → mlx.core.array
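Computes messages between connected nodes.
By default (in the MessagePassing base class), returns the features of source nodes. In this layer, source node features are instead composed with the relation vector of each edge according to message_func. Optional edge_weights can be passed directly in kwargs.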
- Return type:
  mlx.core.array
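The three message_func options can be written as one small function. This NumPy sketch is an illustration, not the library method: the name compose is hypothetical, and the "rotate" branch assumes, as in NBFNet-PyG, that the last dimension is split into real and imaginary halves (so it must be even) and multiplied as complex numbers.

```python
import numpy as np

def compose(x_src, rel, message_func="distmult"):
    """Hypothetical helper: combine source-node features with relation vectors."""
    if message_func == "transe":      # translation: x + r
        return x_src + rel
    if message_func == "distmult":    # element-wise product: x * r
        return x_src * rel
    if message_func == "rotate":      # complex product on (real | imag) halves
        x_re, x_im = np.split(x_src, 2, axis=-1)
        r_re, r_im = np.split(rel, 2, axis=-1)
        return np.concatenate(
            [x_re * r_re - x_im * r_im, x_re * r_im + x_im * r_re], axis=-1
        )
    raise ValueError(f"unknown message_func: {message_func}")

x = np.array([[1.0, 2.0, 3.0, 4.0]])  # reads as (1+3j, 2+4j) under "rotate"
r = np.array([[0.5, 1.0, 2.0, 0.0]])  # reads as (0.5+2j, 1+0j)
print(compose(x, r, "rotate"))         # [[-5.5  2.   3.5  4. ]]
```

The rotate result matches complex multiplication: (1+3j)(0.5+2j) = -5.5+3.5j and (2+4j)(1+0j) = 2+4j.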
- update_nodes(aggregated: mlx.core.array, old: mlx.core.array) → mlx.core.array
Updates the final embeddings given the aggregated messages.
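A plausible update step, assuming this layer mirrors the NBFNet-PyG reference: concatenate the old node state with the aggregated messages, apply a linear layer, then layer normalization (layer_norm=True) and the activation (activation="relu"). This NumPy sketch uses hypothetical weights W and b, and its layer norm has no learnable scale or shift.

```python
import numpy as np

def update_nodes(aggregated, old, W, b, eps=1e-5):
    """Sketch: linear([old || aggregated]) -> layer norm -> ReLU."""
    h = np.concatenate([old, aggregated], axis=-1) @ W + b
    # layer norm over the feature axis (no learnable affine in this sketch)
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    h = (h - mean) / np.sqrt(var + eps)
    return np.maximum(h, 0.0)  # ReLU

rng = np.random.default_rng(0)
old = rng.normal(size=(2, 5, 16))             # previous node states
aggregated = rng.normal(size=(2, 5, 16))      # output of the aggregation step
W = rng.normal(size=(32, 16)) / np.sqrt(32)   # hypothetical weights: 2*dim -> out_dim
b = np.zeros(16)                              # hypothetical bias
h = update_nodes(aggregated, old, W, b)
print(h.shape)  # (2, 5, 16)
```

Feeding the old state into the linear layer lets each layer keep the labeling-trick signal even when few messages arrive.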