mlx_graphs.utils.scatter.scatter


mlx_graphs.utils.scatter.scatter(values: mlx.core.array, index: mlx.core.array, out_size: int | None = None, aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add', axis: int = 0) -> mlx.core.array

Default entry point for all scattering operations. Scatters values into a new array of out_size elements at the positions given by index. Values that share an index are combined using the aggr reduction.

Parameters:
  • values (array) – array of values to scatter into the output array

  • index (array) – indices giving the output position of each entry of values

  • out_size (Optional[int]) – number of elements in the output array (size of its first dimension). If not provided, defaults to the size of the first dimension of values

  • aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – reduction used to combine values scattered to the same index

  • axis (int) – axis along which to scatter
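The following NumPy reference sketch gives a rough idea of what the add, max, min, and mean reductions compute along axis 0. It is not the mlx_graphs implementation. The function name scatter_reference is hypothetical. Filling output slots that receive no values with 0 is an assumption, and the library's handling of such slots may differ.

```python
import numpy as np


def scatter_reference(values, index, out_size=None, aggr="add"):
    """Hypothetical NumPy reference for aggr in {"add", "max", "min", "mean"} along axis 0.

    Assumption (not taken from mlx_graphs): output slots that receive no
    values are filled with 0.
    """
    values = np.asarray(values, dtype=np.float64)
    index = np.asarray(index, dtype=np.int64)
    if out_size is None:
        out_size = values.shape[0]  # mirrors the documented out_size default
    out_shape = (out_size, *values.shape[1:])
    # Number of values landing in each output slot.
    counts = np.bincount(index, minlength=out_size)

    if aggr in ("add", "mean"):
        out = np.zeros(out_shape)
        np.add.at(out, index, values)  # unbuffered: repeated indices accumulate
        if aggr == "mean":
            # Divide each slot's sum by its count; empty slots stay 0.
            denom = np.maximum(counts, 1).reshape((-1,) + (1,) * (values.ndim - 1))
            out = out / denom
        return out

    if aggr in ("max", "min"):
        # Start from +/-inf so the first value scattered to a slot always wins.
        fill = -np.inf if aggr == "max" else np.inf
        out = np.full(out_shape, fill)
        ufunc = np.maximum if aggr == "max" else np.minimum
        ufunc.at(out, index, values)
        out[counts == 0] = 0.0  # empty slots -> 0 (assumption)
        return out

    raise ValueError(f"unsupported aggr for this sketch: {aggr!r}")


print(scatter_reference([1., 1., 1., 1.], [0, 0, 1, 2], 3, "add"))  # [2. 1. 1.]
```

np.add.at and its ufunc siblings are used instead of fancy-index assignment. Plain assignment such as out[index] += values would keep only one write per repeated index, so it would not accumulate.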

Return type:

array

Returns:

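Array with out_size elements containing the values scattered at index and combined with the aggr reduction. For aggr="softmax", the result instead has the same shape as values: each entry is normalized over the entries that share its index (see the example below).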
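The softmax case is the one reduction that does not collapse entries. The following NumPy sketch is a hypothetical reference, not the library's code, and the name scatter_softmax_reference is illustrative. It computes a numerically stable group-wise softmax along axis 0. In this sketch, out_size only sizes the per-group buffers. The returned array matches the shape of values.

```python
import numpy as np


def scatter_softmax_reference(values, index, out_size=None):
    """Hypothetical NumPy reference for aggr="softmax" along axis 0.

    Each entry is normalized against the entries sharing its index, so the
    result has the same shape as `values`, not `out_size` rows.
    """
    values = np.asarray(values, dtype=np.float64)
    index = np.asarray(index, dtype=np.int64)
    if out_size is None:
        out_size = values.shape[0]
    group_shape = (out_size, *values.shape[1:])

    # Per-group maximum, subtracted before exp() for numerical stability.
    group_max = np.full(group_shape, -np.inf)
    np.maximum.at(group_max, index, values)
    exp_shifted = np.exp(values - group_max[index])

    # Per-group sum of exponentials, gathered back to every entry.
    group_sum = np.zeros(group_shape)
    np.add.at(group_sum, index, exp_shifted)
    return exp_shifted / group_sum[index]


print(scatter_softmax_reference([1., 1., 1., 1.], [0, 0, 1, 2], 4))  # [0.5 0.5 1.  1. ]
```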

Example:

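import mlx.core as mx
from mlx_graphs.utils.scatter import scatter

src = mx.array([1., 1., 1., 1.])
index = mx.array([0, 0, 1, 2])
num_nodes = src.shape[0]

# Softmax within each index group; the result has the same shape as src
scatter(src, index, num_nodes, "softmax")
# -> [0.5, 0.5, 1, 1]

# Infer the output size from the largest index
num_nodes = index.max().item() + 1
scatter(src, index, num_nodes, "add")
# -> [2, 1, 1]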