mlx_graphs.utils.scatter.scatter_mean

mlx_graphs.utils.scatter.scatter_mean(values: mlx.core.array, index: mlx.core.array, out_size: int, axis: int = 0) -> mlx.core.array

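Computes the mean of values grouped by index along the given axis. Each slice of values along axis is scattered to the output position given by the corresponding entry of index, and each output position holds the mean of the slices assigned to it.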
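To illustrate these semantics, here is a minimal NumPy reference sketch for the default axis=0 case. The helper scatter_mean_ref is hypothetical and is not the mlx_graphs implementation. It assumes index is a 1-D integer array with one entry per row of values. This page does not say how output rows that receive no values are filled, so the sketch assumes they come out as 0.

```python
import numpy as np


def scatter_mean_ref(values: np.ndarray, index: np.ndarray, out_size: int) -> np.ndarray:
    """Hypothetical reference sketch of scatter-mean along axis 0 (not the mlx_graphs code)."""
    # Sum each row of `values` into the output row named by `index`.
    sums = np.zeros((out_size,) + values.shape[1:], dtype=np.float64)
    np.add.at(sums, index, values)
    # Count how many input rows landed in each output row.
    counts = np.bincount(index, minlength=out_size).astype(np.float64)
    # Assumption: empty groups yield 0 instead of NaN, so clamp the divisor to 1.
    counts = np.maximum(counts, 1.0)
    # Broadcast the per-group counts across any trailing feature dimensions.
    return sums / counts.reshape((out_size,) + (1,) * (values.ndim - 1))


values = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
index = np.array([0, 0, 2])
result = scatter_mean_ref(values, index, out_size=3)
# Rows 0 and 1 are averaged into output row 0, output row 1 receives nothing,
# and row 2 passes through: result rows are [2, 3], [0, 0], [10, 20].
```

In a graph setting, values would typically hold one feature vector per edge and index the destination node of each edge, so the result is the average incoming message per node.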

Parameters:
  • values (array) – Input array whose entries are scattered by index and then averaged.

  • index (array) – Array of integer indices, one per slice of values along axis, giving the output position (group) each slice is assigned to.

  • out_size (int) – Size of the output along axis, i.e. the number of groups.

  • axis (int) – Axis along which to scatter. Defaults to 0.
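The interaction between axis and out_size can be illustrated with a second NumPy sketch. The helper scatter_mean_axis is hypothetical and is not the mlx_graphs implementation. It assumes index is 1-D with length values.shape[axis], that the output keeps every other dimension of values while its length along axis becomes out_size, and that groups receiving no values come out as 0.

```python
import numpy as np


def scatter_mean_axis(values: np.ndarray, index: np.ndarray, out_size: int, axis: int = 0) -> np.ndarray:
    """Hypothetical sketch of scatter-mean along an arbitrary axis (not the mlx_graphs code)."""
    # Move the scatter axis to the front so grouping always happens on axis 0.
    moved = np.moveaxis(values, axis, 0)
    sums = np.zeros((out_size,) + moved.shape[1:], dtype=np.float64)
    np.add.at(sums, index, moved)
    # Number of slices assigned to each group; assumption: empty groups divide by 1.
    counts = np.maximum(np.bincount(index, minlength=out_size), 1).astype(np.float64)
    means = sums / counts.reshape((out_size,) + (1,) * (moved.ndim - 1))
    # Restore the original axis order; only the length of `axis` changed (to out_size).
    return np.moveaxis(means, 0, axis)


# Two rows of four columns; the columns are grouped by `index` along axis=1.
values = np.array([[1.0, 3.0, 5.0, 7.0],
                   [2.0, 4.0, 6.0, 8.0]])
index = np.array([1, 1, 0, 0])
out = scatter_mean_axis(values, index, out_size=2, axis=1)
# out has shape (2, 2): columns 2-3 average into group 0, columns 0-1 into group 1,
# giving [[6., 2.], [7., 3.]].
```

Moving the grouped axis to the front keeps the accumulation logic identical to the axis=0 case.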

Return type:

array

Returns:

An array containing the mean of values for each group in index.