Source code for mlx_graphs.nn.pooling.global_pooling
from typing import Optional
import mlx.core as mx
from mlx_graphs.utils import scatter
[docs]
def global_add_pool(
    node_features: mx.array, batch_indices: Optional[mx.array] = None
) -> mx.array:
    """Sum node features into a global graph-level representation.

    When `batch_indices` is given, the sum is computed separately for each
    graph in the batch; otherwise every node is treated as belonging to one
    single graph.

    The result has shape (1, node_features_dim) if `batch_indices` is None,
    otherwise (num_batches, node_features_dim).

    Args:
        node_features: Node features array.
        batch_indices: Batch array of shape (node_features.shape[0]),
            indicating for each node its batch index.

    Returns:
        An array with summed node features for all provided graphs.
    """
    if batch_indices is not None:
        # The output size along axis 0 is the largest batch index plus one.
        num_graphs = batch_indices.max().item() + 1
        return scatter(
            node_features,
            batch_indices,
            out_size=num_graphs,
            axis=0,
            aggr="add",
        )

    # Single-graph case: reduce over the node axis but keep it as a
    # length-1 dimension.
    return node_features.sum(axis=0, keepdims=True)
[docs]
def global_max_pool(
    node_features: mx.array, batch_indices: Optional[mx.array] = None
) -> mx.array:
    """Take the feature-wise maximum over nodes as a graph-level representation.

    When `batch_indices` is given, the maximum is computed separately for
    each graph in the batch; otherwise every node is treated as belonging to
    one single graph.

    The result has shape (1, node_features_dim) if `batch_indices` is None,
    otherwise (num_batches, node_features_dim).

    Args:
        node_features: Node features array.
        batch_indices: Batch array of shape (node_features.shape[0]),
            indicating for each node its batch index.

    Returns:
        An array with maximum node features for all provided graphs.
    """
    if batch_indices is not None:
        # The output size along axis 0 is the largest batch index plus one.
        num_graphs = batch_indices.max().item() + 1
        return scatter(
            node_features,
            batch_indices,
            out_size=num_graphs,
            axis=0,
            aggr="max",
        )

    # Single-graph case: reduce over the node axis but keep it as a
    # length-1 dimension.
    return node_features.max(axis=0, keepdims=True)
[docs]
def global_mean_pool(
    node_features: mx.array, batch_indices: Optional[mx.array] = None
) -> mx.array:
    """Take the feature-wise mean over nodes as a graph-level representation.

    When `batch_indices` is given, the mean is computed separately for each
    graph in the batch; otherwise every node is treated as belonging to one
    single graph.

    The result has shape (1, node_features_dim) if `batch_indices` is None,
    otherwise (num_batches, node_features_dim).

    Args:
        node_features: Node features array.
        batch_indices: Batch array of shape (node_features.shape[0]),
            indicating for each node its batch index.

    Returns:
        An array with averaged node features for all provided graphs.
    """
    if batch_indices is not None:
        # The output size along axis 0 is the largest batch index plus one.
        num_graphs = batch_indices.max().item() + 1
        return scatter(
            node_features,
            batch_indices,
            out_size=num_graphs,
            axis=0,
            aggr="mean",
        )

    # Single-graph case: reduce over the node axis but keep it as a
    # length-1 dimension.
    return node_features.mean(axis=0, keepdims=True)