mlx_graphs.nn.global_mean_pool


class mlx_graphs.nn.global_mean_pool(node_features: mlx.core.array, batch_indices: mlx.core.array | None = None) → mlx.core.array


Computes the feature-wise mean over all node features to obtain a global, graph-level representation.

If batch_indices is provided, pooling is applied separately to each graph in the batch. The returned shape is (1, node_features_dim) if batch_indices is None; otherwise it is (num_batches, node_features_dim), where num_batches is the number of graphs in the batch.
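The pooling semantics described above can be illustrated with a small NumPy reference sketch. It is not the mlx_graphs implementation (which operates on mlx.core.array); the function name global_mean_pool_reference is hypothetical, and returning a row of zeros for a graph index that has no nodes is an assumption of this sketch, not documented library behaviour.

```python
import numpy as np


def global_mean_pool_reference(node_features, batch_indices=None):
    """Hypothetical NumPy sketch of global mean pooling (not the mlx_graphs code)."""
    node_features = np.asarray(node_features, dtype=np.float64)
    if batch_indices is None:
        # Single graph: average over the node axis, keeping a leading axis of size 1.
        return node_features.mean(axis=0, keepdims=True)

    batch_indices = np.asarray(batch_indices, dtype=np.int64)
    num_graphs = int(batch_indices.max()) + 1

    # Scatter-add each node's features into the row of the graph it belongs to.
    sums = np.zeros((num_graphs, node_features.shape[1]))
    np.add.at(sums, batch_indices, node_features)

    # Count nodes per graph; the max(count, 1) guard is an assumption of this
    # sketch so that a graph index with no nodes yields zeros instead of NaN.
    counts = np.bincount(batch_indices, minlength=num_graphs)
    return sums / np.maximum(counts, 1)[:, None]


x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(global_mean_pool_reference(x))                      # [[3. 4.]]
print(global_mean_pool_reference(x, np.array([0, 0, 1])))  # [[2. 3.] [5. 6.]]
```

Without batch_indices the three nodes are averaged into a single (1, 2) row; with batch_indices [0, 0, 1] the first two nodes form graph 0 and the last node forms graph 1, giving a (2, 2) result.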

Parameters:
  • node_features (array) – Node features array of shape (num_nodes, node_features_dim).

  • batch_indices (Optional[array]) – Batch array of shape (node_features.shape[0],) giving, for each node, the index of the graph it belongs to. Defaults to None, in which case all nodes are treated as a single graph.
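When the node features of several graphs are concatenated along axis 0, batch_indices simply repeats each graph's index once per node. The sketch below builds such an array with NumPy from per-graph node counts; make_batch_indices is a hypothetical helper, not part of mlx_graphs, and the library's own batching utilities may already produce this array for you.

```python
import numpy as np


def make_batch_indices(num_nodes_per_graph):
    """Hypothetical helper: batch indices for graphs concatenated in order."""
    counts = np.asarray(num_nodes_per_graph, dtype=np.int64)
    # Graph i contributes counts[i] consecutive entries equal to i.
    return np.repeat(np.arange(len(counts)), counts)


batch_indices = make_batch_indices([2, 3, 1])
print(batch_indices.tolist())  # [0, 0, 1, 1, 1, 2]
```

The resulting array has one entry per node, matching the (node_features.shape[0],) shape expected for batch_indices.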

Return type:

array

Returns:

An array with averaged node features for all provided graphs.

__call__(**kwargs)

Call self as a function.