mlx_graphs.utils.scatter.scatter
- mlx_graphs.utils.scatter.scatter(values: mlx.core.array, index: mlx.core.array, out_size: int | None = None, aggr: Literal['add', 'max', 'mean', 'softmax', 'min'] = 'add', axis: int = 0) → mlx.core.array
Default function for performing all scattering operations. Scatters values into an empty array of out_size elements at the positions given by index, combining values that share an index with the aggr reduction.
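To make the reduction semantics concrete, here is a minimal pure-Python sketch for 1-D inputs. It is not the library's implementation: `scatter_reference` is a hypothetical helper, it covers only the "add", "max", "min" and "mean" reductions, and filling slots that receive no value with 0.0 is an assumption of the sketch rather than documented behavior of scatter.

```python
from typing import Callable, Dict, List, Optional, Sequence


def scatter_reference(
    values: Sequence[float],
    index: Sequence[int],
    out_size: Optional[int] = None,
    aggr: str = "add",
) -> List[float]:
    """Hypothetical pure-Python reference for 1-D scatter reductions."""
    # Mirror the documented default: out_size falls back to len(values).
    if out_size is None:
        out_size = len(values)

    # Collect every value that targets each output slot.
    buckets: List[List[float]] = [[] for _ in range(out_size)]
    for v, i in zip(values, index):
        buckets[i].append(v)

    reducers: Dict[str, Callable[[List[float]], float]] = {
        "add": sum,
        "max": max,
        "min": min,
        "mean": lambda xs: sum(xs) / len(xs),
    }
    reducer = reducers[aggr]

    # Assumption of this sketch: slots that receive no value stay 0.0.
    return [float(reducer(b)) if b else 0.0 for b in buckets]
```

For example, `scatter_reference([1., 3., 2., 5.], [0, 0, 1, 2], out_size=3, aggr="max")` returns `[3.0, 2.0, 5.0]`, since the first two values compete for slot 0.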
- Parameters:
  - values (array) – array with all the values to scatter in the output array
  - index (array) – array with the indices at which to scatter the values
  - out_size (Optional[int]) – number of elements in the output array (size of the first dimension). If not provided, uses the number of elements in values
  - aggr (Literal['add', 'max', 'mean', 'softmax', 'min']) – reduction method applied to the values scattered to the same index
  - axis (int) – axis along which to apply the scattering
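In graph code, values is often a 2-D array with one feature row per edge. The sketch below assumes that scattering along axis=0 treats each row as a single element, so index picks the output row that each input row is summed into. `scatter_rows_add` is a hypothetical pure-Python helper illustrating the "add" case, not part of mlx_graphs.

```python
from typing import List, Sequence


def scatter_rows_add(
    values: Sequence[Sequence[float]],
    index: Sequence[int],
    out_size: int,
) -> List[List[float]]:
    """Hypothetical sketch: sum rows of a 2-D array into out_size rows (axis=0)."""
    num_features = len(values[0]) if values else 0
    # Start from an all-zero output with out_size rows.
    out = [[0.0] * num_features for _ in range(out_size)]
    for row, i in zip(values, index):
        # Each input row is added element-wise into output row i.
        for j, x in enumerate(row):
            out[i][j] += x
    return out
```

With three edge messages `[[1., 2.], [3., 4.], [5., 6.]]` sent to nodes `[0, 0, 1]`, the two messages for node 0 are summed, giving `[[4.0, 6.0], [5.0, 6.0]]` for `out_size=2`.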
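- Returns:
Array with out_size elements containing the values scattered at index and reduced with the aggr method (for "softmax", the output instead has the same shape as values, with each entry normalized within its index group)
- Return type:
mlx.core.array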
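Example:
    import mlx.core as mx
    from mlx_graphs.utils.scatter import scatter

    src = mx.array([1., 1., 1., 1.])
    index = mx.array([0, 0, 1, 2])

    # "softmax" normalizes within each index group and keeps one entry per value
    num_nodes = src.shape[0]
    scatter(src, index, num_nodes, "softmax")
    # -> [0.5, 0.5, 1, 1]

    # "add" sums the values sharing an index; out_size must cover the largest index
    num_nodes = index.max().item() + 1
    scatter(src, index, num_nodes, "add")
    # -> [2, 1, 1]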
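The "softmax" result in the example can be reproduced with a short pure-Python sketch of a group-wise softmax. `scatter_softmax_reference` is hypothetical; it assumes the usual definition (exponentiate, then divide by the sum of exponentials sharing the same index) with the per-group maximum subtracted for numerical stability, and it returns one entry per input value.

```python
import math
from typing import Dict, List, Sequence


def scatter_softmax_reference(values: Sequence[float], index: Sequence[int]) -> List[float]:
    """Hypothetical reference: softmax computed separately within each index group."""
    # Per-group maximum, subtracted before exponentiating for numerical stability.
    group_max: Dict[int, float] = {}
    for v, i in zip(values, index):
        group_max[i] = max(v, group_max.get(i, v))
    exps = [math.exp(v - group_max[i]) for v, i in zip(values, index)]

    # Per-group normalizer: the sum of exponentials sharing the same index.
    group_sum: Dict[int, float] = {}
    for e, i in zip(exps, index):
        group_sum[i] = group_sum.get(i, 0.0) + e

    # The output keeps one entry per input value, not one per group.
    return [e / group_sum[i] for e, i in zip(exps, index)]
```

Calling it on src = [1., 1., 1., 1.] and index = [0, 0, 1, 2] gives [0.5, 0.5, 1.0, 1.0], matching the example.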