Source code for mlx_graphs.utils.array_ops
from typing import Optional
import mlx.core as mx
[docs]
def broadcast(src: mx.array, other: mx.array, dim: int) -> mx.array:
    """
    Reshape and expand ``src`` so that its shape lines up with ``other``.

    Useful for cases such as index broadcasting, where a low-rank array has
    to be matched against a higher-rank one.

    Args:
        src: source array to broadcast.
        other: array whose shape ``src`` should match.
        dim: axis of ``other`` on which a 1-D ``src`` is placed. A negative
            value is counted from the end of ``other``'s dimensions.

    Returns:
        ``src`` with singleton axes inserted and then expanded to
        ``other.shape`` through :func:`expand`.
    """
    # Normalise a negative axis relative to the rank of `other`.
    axis = dim if dim >= 0 else other.ndim + dim

    # A 1-D source gets `axis` leading singleton axes, so that its data ends
    # up on axis `axis`.
    if src.ndim == 1:
        for _ in range(axis):
            src = mx.expand_dims(src, 0)

    # Append trailing singleton axes until the ranks agree.
    while src.ndim < other.ndim:
        src = mx.expand_dims(src, -1)

    return expand(src, other.shape)
[docs]
def expand(array: mx.array, new_shape: tuple) -> mx.array:
    """
    Expand the dimensions of an array to a new shape.

    Args:
        array: The array to expand.
        new_shape: The new shape desired. It must have at least as many
            dimensions as ``array``. Shapes are compared on their trailing
            dimensions, and every dimension of ``new_shape`` must be greater
            than or equal to the matching dimension of ``array``. Extra
            leading dimensions of ``new_shape`` become new leading axes.

    Returns:
        The result of ``mx.broadcast_to(array, new_shape)``, i.e. an array
        whose shape is exactly ``new_shape``.

    Raises:
        ValueError: If ``new_shape`` has fewer dimensions than ``array``, or
            if any dimension of ``new_shape`` is smaller than the matching
            dimension of ``array``. Sizes that pass these checks but are not
            broadcast-compatible (e.g. 3 -> 4) are left for
            ``mx.broadcast_to`` to reject.
    """
    orig_shape = tuple(array.shape)
    new_shape = tuple(new_shape)

    # Without this check, zipping the two shapes would silently drop the
    # dimensions of the longer one.
    if len(new_shape) < len(orig_shape):
        raise ValueError(
            f"New shape {new_shape} must have at least as many dimensions as "
            f"the original shape {orig_shape}"
        )

    # Broadcasting aligns shapes on their trailing dimensions, so only the
    # last len(orig_shape) entries of new_shape have a counterpart to check.
    offset = len(new_shape) - len(orig_shape)
    if any(
        new_dim < orig_dim
        for new_dim, orig_dim in zip(new_shape[offset:], orig_shape)
    ):
        raise ValueError("New shape must be greater than or equal to original shape")

    # After the check above, new_shape already is the target shape; no
    # element-wise max with the original shape is needed.
    return mx.broadcast_to(array, new_shape)
[docs]
def one_hot(labels: mx.array, num_classes: Optional[int] = None) -> mx.array:
    """
    Build one-hot encoded vectors for every element of `labels`.

    Given an array of labels [num_elements,], returns an array with shape
    [num_elements, num_classes] where each row is an all-zero vector with
    a one at the index of the corresponding label.

    Args:
        labels: Array with the labels to transform to one-hot encoded vectors.
        num_classes: Number of labels for the one-hot encoding. By default,
            ``num_classes`` is set to `max_label + 1`.

    Returns:
        An array of shape [num_elements, num_classes] with one-hot encoded
        vectors.
    """
    # Infer the number of classes from the largest label when not given.
    if num_classes is None:
        num_classes = (labels.max() + 1).item()

    num_elements = labels.shape[0]
    encoded = mx.zeros((num_elements, num_classes))

    # Pair each row index with its label so that encoded[i, labels[i]] is set.
    row_ids = mx.arange(num_elements)
    col_ids = labels.squeeze()
    encoded[row_ids, col_ids] = 1
    return encoded
[docs]
def pairwise_distances(x: mx.array, y: mx.array) -> mx.array:
    """
    Compute pairwise distances between points in vectors x and y.

    The distance is the norm (``mx.linalg.norm`` with its default order)
    of the difference between every point of ``x`` and every point of ``y``.

    Args:
        x: array of shape (N, D)
        y: array of shape (M, D)

    Returns:
        Array of shape (N, M) where entry ``[i, j]`` is the distance between
        ``x[i]`` and ``y[j]``.

    Raises:
        AssertionError: If ``x`` and ``y`` do not share the same second
            dimension ``D``.
    """
    assert x.shape[1] == y.shape[1], "Input vectors must have the same dimensionality"

    # (N, 1, D) - (M, D) broadcasts to (N, M, D): all pairwise differences.
    differences = mx.expand_dims(x, 1) - y

    # Return the computed array directly. Rebuilding it from `.tolist()`
    # copied every element through Python objects and collapsed the shape
    # of empty results (an empty nested list carries no (N, M) shape).
    return mx.linalg.norm(differences, axis=-1)
[docs]
def index_to_mask(index: mx.array, size: Optional[int] = None) -> mx.array:
    """Turn an array of indices into a boolean mask.

    Args:
        index: Array of indices where the mask will be True. It is flattened
            before use.
        size: Length of the returned mask array. By default,
            the size is set to the maximum index in the indices + 1.

    Returns:
        A boolean array of length ``size``, with `True` at the given
        ``index`` indices and `False` elsewhere.

    Example:

    .. code-block:: python

        >>> index = mx.array([1, 2, 3])
        >>> index_to_mask(index, size=5)
        array([False, True, True, True, False], dtype=bool)
    """
    flat_index = index.reshape(-1)

    # Default to the smallest mask that can hold the largest index.
    if size is None:
        size = flat_index.max().item() + 1

    mask = mx.zeros(size, dtype=mx.bool_)
    mask[flat_index] = True
    return mask