mlx_graphs.nn.global_max_pool
- class mlx_graphs.nn.global_max_pool(node_features: mlx.core.array, batch_indices: mlx.core.array | None = None) → mlx.core.array
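- Takes the feature-wise maximum over all nodes to obtain a global, graph-level representation.
- If batch_indices is provided, pooling is applied separately to each graph in the batch. The returned shape is (1, node_features_dim) if batch_indices is None; otherwise it is (num_batches, node_features_dim).
- Parameters:
  - node_features (mlx.core.array): Node features of shape (num_nodes, node_features_dim).
  - batch_indices (mlx.core.array | None): Index of the graph each node belongs to, of shape (num_nodes,). If None (the default), all nodes are pooled as a single graph.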
- Return type:
- mlx.core.array
- Returns:
- An array containing the feature-wise maximum of the node features for each provided graph.
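The shape rules above can be illustrated with a small NumPy sketch. It reproduces only the documented semantics and is not the mlx_graphs implementation, which operates on mlx.core.array. The helper name global_max_pool_reference is hypothetical, and the sketch assumes batch_indices holds contiguous graph ids starting at 0.

```python
import numpy as np


def global_max_pool_reference(node_features, batch_indices=None):
    """Hypothetical NumPy sketch of the documented pooling semantics.

    node_features: array of shape (num_nodes, node_features_dim)
    batch_indices: optional array of shape (num_nodes,) giving each node's
        graph id; assumed to be contiguous integers 0..num_batches-1.
    """
    node_features = np.asarray(node_features, dtype=float)
    if batch_indices is None:
        # Single graph: max over the node axis, keeping a leading axis of size 1.
        return node_features.max(axis=0, keepdims=True)

    batch_indices = np.asarray(batch_indices)
    num_batches = int(batch_indices.max()) + 1
    # Start from -inf so any real feature value replaces it
    # (a graph id with no nodes would stay at -inf in this sketch).
    out = np.full((num_batches, node_features.shape[1]), -np.inf)
    # Scatter-max: out[b] = max(out[b], node_features[i]) for each node i in graph b.
    np.maximum.at(out, batch_indices, node_features)
    return out


# Four nodes with two features each, split across two graphs.
x = np.array([[1.0, 5.0],
              [3.0, 2.0],
              [0.0, 7.0],
              [4.0, 1.0]])
batch = np.array([0, 0, 1, 1])

pooled_all = global_max_pool_reference(x)           # shape (1, 2): [[4., 7.]]
pooled_per_graph = global_max_pool_reference(x, batch)  # shape (2, 2): [[3., 5.], [4., 7.]]
```

The per-graph case is a scatter-max: each node's feature row is folded into the row of its graph with an elementwise maximum, which is why the output has one row per graph rather than one row per node.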
- __call__(**kwargs)
- Call self as a function.