mlx_graphs.utils.array_ops.one_hot



mlx_graphs.utils.array_ops.one_hot(labels: mlx.core.array, num_classes: int | None = None) → mlx.core.array

Creates one-hot encoded vectors for all elements provided in labels.

Given an array of labels of shape [num_elements,], returns an array of shape [num_elements, num_classes] where each row is an all-zero vector with a one at the column given by the corresponding label.
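To make the row-per-label layout concrete, here is a minimal NumPy sketch of the behavior described above. It is a reference reimplementation for illustration only, not the mlx_graphs source; the helper name `one_hot_reference` and the integer output dtype are assumptions.

```python
import numpy as np

def one_hot_reference(labels, num_classes=None):
    """Hypothetical NumPy sketch of one_hot semantics (not the mlx_graphs code)."""
    labels = np.asarray(labels, dtype=np.int64)
    if num_classes is None:
        # Default described for this function: max_label + 1.
        num_classes = int(labels.max()) + 1
    # Start from all zeros: one row per label, one column per class.
    out = np.zeros((labels.shape[0], num_classes), dtype=np.int64)
    # Row i gets a single 1 at column labels[i].
    out[np.arange(labels.shape[0]), labels] = 1
    return out

encoded = one_hot_reference([0, 2, 1])
print(encoded)
# [[1 0 0]
#  [0 0 1]
#  [0 1 0]]
```

Each of the three labels produces one row, and the position of the 1 in that row is the label value.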

Parameters:
  • labels (array) – Array of integer labels to encode as one-hot vectors.

  • num_classes (Optional[int]) – Number of classes in the one-hot encoding, i.e. the length of each encoded vector. If None (the default), num_classes is set to max_label + 1, where max_label is the largest value in labels.
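The effect of num_classes on the output width can be sketched with NumPy's identity-matrix indexing idiom, where indexing the rows of an identity matrix by the labels yields one-hot rows. This is an illustrative stand-in chosen for brevity, not necessarily how mlx_graphs computes the result, and the integer dtype is an assumption.

```python
import numpy as np

labels = np.array([1, 3, 3, 0])

# Default case: num_classes = max_label + 1 = 4, so each row has length 4.
default = np.eye(int(labels.max()) + 1, dtype=np.int64)[labels]
print(default.shape)  # (4, 4)

# Explicit num_classes larger than max_label + 1: the extra
# trailing columns (classes 4 and 5) stay all zero.
padded = np.eye(6, dtype=np.int64)[labels]
print(padded.shape)  # (4, 6)
```

Passing num_classes explicitly is useful when a batch of labels does not contain every class but all encodings must share the same width.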

Return type:

array

Returns:

An array of shape [num_elements, num_classes] with one-hot encoded vectors.