mlx_graphs.utils.scatter.scatter_softmax

mlx_graphs.utils.scatter.scatter_softmax(values: mlx.core.array, index: mlx.core.array, out_size: int, axis: int = 0) → mlx.core.array

Computes a softmax over values along the specified axis, normalizing separately within each group of entries that share the same index.

Parameters:
  • values (array) – Input array whose entries are grouped by index and normalized with a softmax within each group.

  • index (array) – Array of indices that assigns each entry of ‘values’ to a group.

  • out_size (int) – Size of the output array.

  • axis (int) – Axis along which to scatter. Defaults to 0.

Return type:

array

Returns:

The resulting array after applying scatter and softmax operations on the input ‘values’.

Example:

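import mlx.core as mx
from mlx_graphs.utils.scatter import scatter_softmax

src = mx.array([1., 1., 1., 1.])
index = mx.array([0, 0, 1, 2])
num_nodes = src.shape[0]

# Entries 0 and 1 share index 0, so they split the probability mass;
# entries 2 and 3 are alone in their groups and each normalize to 1.
scatter_softmax(src, index, num_nodes)
>>> mx.array([0.5, 0.5, 1, 1])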
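
To make the grouping concrete, here is a minimal NumPy sketch of a grouped softmax for 1-D inputs. It is an illustration only, not the mlx-graphs implementation. The name scatter_softmax_reference is hypothetical, and the sketch assumes every value in index is smaller than out_size. Subtracting the per-group maximum before exponentiating is a common numerical-stability choice, not something this page confirms the library does.

```python
import numpy as np


def scatter_softmax_reference(values, index, out_size):
    """Grouped softmax for 1-D inputs (illustrative sketch, hypothetical helper)."""
    values = np.asarray(values, dtype=np.float64)
    index = np.asarray(index, dtype=np.int64)

    # Per-group maximum, subtracted before exp() for numerical stability.
    group_max = np.full(out_size, -np.inf)
    np.maximum.at(group_max, index, values)
    shifted = np.exp(values - group_max[index])

    # Per-group sum of exponentials: the softmax denominator for each group.
    group_sum = np.zeros(out_size)
    np.add.at(group_sum, index, shifted)

    # Gather each entry's group denominator and normalize.
    return shifted / group_sum[index]


result = scatter_softmax_reference([1.0, 1.0, 1.0, 1.0], [0, 0, 1, 2], out_size=4)
print(result)
```

With the inputs from the example above, the first two entries share a group and each receive 0.5, while the last two are alone in their groups and each receive 1.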