mlx_graphs.nn.global_mean_pool
- class mlx_graphs.nn.global_mean_pool(node_features: mlx.core.array, batch_indices: mlx.core.array | None = None) → mlx.core.array
Bases:
Computes the feature-wise mean over all nodes to obtain a global graph-level representation.
If batch_indices is provided, pooling is applied separately to each graph in the batch. The returned shape is (1, node_features_dim) if batch_indices is None; otherwise it is (num_batches, node_features_dim).
- Parameters:
node_features (mlx.core.array)
batch_indices (mlx.core.array | None), defaults to None
- Return type:
mlx.core.array
- Returns:
An array with averaged node features for all provided graphs.
- __call__(**kwargs)
Call self as a function.
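
The pooling behavior described for global_mean_pool can be illustrated with a small reference sketch. This is not the library's implementation: it uses plain nested lists instead of mlx.core.array so that it runs without dependencies, the helper name global_mean_pool_reference is hypothetical, and it assumes that batch_indices holds one integer graph index per node, numbered contiguously from 0.

```python
from collections import defaultdict


def global_mean_pool_reference(node_features, batch_indices=None):
    """Hypothetical reference sketch: feature-wise mean per graph on nested lists."""
    if batch_indices is None:
        # No batch information: treat every node as part of a single graph.
        batch_indices = [0] * len(node_features)

    # Group node feature rows by the graph index they belong to.
    groups = defaultdict(list)
    for row, graph_id in zip(node_features, batch_indices):
        groups[graph_id].append(row)

    # Average each feature column within every graph, ordered by graph index.
    return [
        [sum(column) / len(rows) for column in zip(*rows)]
        for _, rows in sorted(groups.items())
    ]


# Three nodes with two features each.
node_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Without batch indices the result has a single row: shape (1, 2).
single = global_mean_pool_reference(node_features)

# Nodes 0 and 1 belong to graph 0, node 2 to graph 1: shape (2, 2).
batched = global_mean_pool_reference(node_features, [0, 0, 1])
```

With the actual function, the same inputs would presumably be passed as mlx.core.array objects, and the result would be an array of the shape stated above rather than a nested list.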