mlx_graphs.utils.scatter.scatter_softmax
- mlx_graphs.utils.scatter.scatter_softmax(values: mlx.core.array, index: mlx.core.array, out_size: int, axis: int = 0) → mlx.core.array
Computes a softmax over values grouped by index: the entries of values along axis that share the same index are normalized together, so the outputs within each group sum to 1.
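For a 1-D input with axis=0, write x for values and g_i for index[i]. The grouped softmax can then be stated as below. This is the standard grouped-softmax definition, not taken from the library source. The per-group maximum m_g is the usual numerical-stability shift and is assumed here. It cancels in the ratio, so it does not change the result.

$$
\operatorname{scatter\_softmax}(x)_i = \frac{\exp\left(x_i - m_{g_i}\right)}{\sum_{j:\, g_j = g_i} \exp\left(x_j - m_{g_i}\right)},
\qquad m_g = \max_{j:\, g_j = g} x_j
$$

Take values = [1, 1, 1, 1] and index = [0, 0, 1, 2]. Group 0 holds two equal entries, so each one gets 1 / (1 + 1) = 0.5. Groups 1 and 2 each hold a single entry, and each of those gets 1.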
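- Parameters:
  - values (mlx.core.array) – The input array whose entries are normalized.
  - index (mlx.core.array) – The group index of each entry of values along axis. Entries with the same index form one softmax group.
  - out_size (int) – The number of groups, i.e. the size of the scattered intermediate array. It should be at least max(index) + 1.
  - axis (int) – The axis of values along which to scatter. Defaults to 0.
- Returns:
  The resulting array after applying the scatter and softmax operations to the input values. It has the same shape as values.
- Return type:
  mlx.core.array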
Example:
    import mlx.core as mx
    from mlx_graphs.utils.scatter import scatter_softmax

    src = mx.array([1., 1., 1., 1.])
    index = mx.array([0, 0, 1, 2])
    num_nodes = src.shape[0]
    scatter_softmax(src, index, num_nodes)
    # values: [0.5, 0.5, 1, 1]
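The following is a minimal pure-Python reference for the 1-D, axis=0 case. It makes the grouping concrete. scatter_softmax_reference is a hypothetical helper written for illustration and is not part of mlx_graphs. The max-shift step is an assumption about how a numerically stable version would be built. The sketch also assumes non-negative indices smaller than out_size. It reproduces the values of the example above.

```python
import math
from typing import List, Sequence


def scatter_softmax_reference(
    values: Sequence[float], index: Sequence[int], out_size: int
) -> List[float]:
    """Hypothetical reference (not the mlx_graphs implementation).

    Entries of `values` that share the same `index` entry are softmax-normalized
    together. 1-D only, so there is no `axis` argument.
    """
    if len(values) != len(index):
        raise ValueError("values and index must have the same length")

    # 1) Scatter-max: per-group maximum, used only for numerical stability.
    group_max = [-math.inf] * out_size
    for v, g in zip(values, index):
        group_max[g] = max(group_max[g], v)

    # 2) Gather each entry's group maximum, shift by it, and exponentiate.
    shifted = [math.exp(v - group_max[g]) for v, g in zip(values, index)]

    # 3) Scatter-add: per-group sum of the exponentials.
    group_sum = [0.0] * out_size
    for e, g in zip(shifted, index):
        group_sum[g] += e

    # 4) Gather each entry's group sum and normalize, so every group sums to 1.
    return [e / group_sum[g] for e, g in zip(shifted, index)]


print(scatter_softmax_reference([1.0, 1.0, 1.0, 1.0], [0, 0, 1, 2], 4))
# [0.5, 0.5, 1.0, 1.0]
```

The four steps run in this order: a scatter-max, a gather, a scatter-add, and a second gather. An array implementation would replace the Python loops with scatter operations along axis. Subtracting the group maximum keeps large inputs such as [1000.0, 1000.0] from overflowing in math.exp.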