"""Array utilities for mlx_graphs (module ``mlx_graphs.utils.array_ops``)."""

from typing import Optional

import mlx.core as mx

def broadcast(src: mx.array, other: mx.array, dim: int) -> mx.array:
    """
    Make the shape of ``src`` broadcastable to the shape of ``other``.

    May be required in some situations like index broadcasting, where a 1-D
    array must be aligned with dimension ``dim`` of ``other``.

    Args:
        src: source array to broadcast.
        other: other array whose shape ``src`` is expanded to.
        dim: dimension of ``other`` that a 1-D ``src`` is aligned with.
            Negative values count from the end of ``other``'s dimensions.
            Only used when ``src`` is 1-D; for higher-rank ``src`` new axes
            are appended at the end regardless of ``dim``.

    Returns:
        ``src`` reshaped (via added size-1 axes) and expanded to
        ``other.shape``.
    """
    if dim < 0:
        dim = other.ndim + dim
    if src.ndim == 1:
        # Prepend `dim` axes so the single axis of `src` lands at position `dim`.
        for _ in range(0, dim):
            src = mx.expand_dims(src, 0)
    # Append trailing size-1 axes until `src` has the same rank as `other`.
    for _ in range(src.ndim, other.ndim):
        src = mx.expand_dims(src, -1)
    src = expand(src, other.shape)
    return src
def expand(array: mx.array, new_shape: tuple) -> mx.array:
    """
    Expand (broadcast) an array to a new shape.

    Shapes are aligned from the trailing dimension, following the standard
    broadcasting convention, so ``new_shape`` may have more dimensions than
    ``array``; any extra leading dimensions are prepended.

    Args:
        array: The array to expand.
        new_shape: The new shape desired. It must have at least
            ``array.ndim`` dimensions, and each of its trailing dimensions
            must be greater than or equal to the matching dimension of
            ``array`` (and broadcast-compatible with it, otherwise
            ``mx.broadcast_to`` rejects it).

    Returns:
        The result of ``mx.broadcast_to(array, new_shape)``.

    Raises:
        ValueError: If ``new_shape`` has fewer dimensions than ``array`` or
            any aligned dimension of ``new_shape`` is smaller than the
            original one.
    """
    orig_shape = tuple(array.shape)
    new_shape = tuple(new_shape)
    if len(new_shape) < len(orig_shape):
        raise ValueError("New shape must be greater than or equal to original shape")
    # Compare against the trailing dimensions of `new_shape`: zipping from the
    # left would pair up the wrong axes whenever the ranks differ.
    trailing = new_shape[len(new_shape) - len(orig_shape):]
    if any(new_dim < orig_dim for new_dim, orig_dim in zip(trailing, orig_shape)):
        raise ValueError("New shape must be greater than or equal to original shape")
    return mx.broadcast_to(array, new_shape)
def one_hot(labels: mx.array, num_classes: Optional[int] = None) -> mx.array:
    """
    Creates one-hot encoded vectors for all elements provided in `labels`.

    Given an array of labels [num_elements,], returns an array with shape
    [num_elements, num_classes] where each row is an all-zero vector with a
    one at the index of the corresponding label.

    Args:
        labels: Array with the labels to transform to one-hot encoded vectors.
            Integer class indices of shape [num_elements,] (a trailing
            singleton axis, e.g. [num_elements, 1], is squeezed away).
        num_classes: Number of labels for the one-hot encoding.
            By default, ``num_classes`` is set to `max_label + 1`
            (or 0 when ``labels`` is empty).

    Returns:
        An array of shape [num_elements, num_classes] with one-hot encoded
        vectors, created with ``mx.zeros``' default dtype.
    """
    num_elements = labels.shape[0]
    if num_classes is None:
        # `max` of an empty array is undefined; empty input has no classes.
        num_classes = int(labels.max().item()) + 1 if labels.size > 0 else 0
    # Named `encoded` so it does not shadow this function's own name.
    encoded = mx.zeros((num_elements, num_classes))
    if num_elements == 0:
        return encoded
    encoded[mx.arange(num_elements), labels.squeeze()] = 1
    return encoded
def pairwise_distances(x: mx.array, y: mx.array) -> mx.array:
    """
    Compute pairwise Euclidean distances between the rows of x and y.

    Args:
        x: array of shape (N, D)
        y: array of shape (M, D)

    Returns:
        Array of shape (N, M) whose entry ``[i, j]`` is the L2 norm of
        ``x[i] - y[j]``.

    Raises:
        ValueError: If ``x`` and ``y`` have a different number of columns D.
    """
    # Explicit check (not `assert`) so it survives `python -O`; without it a
    # y with D == 1 would silently broadcast against x.
    if x.shape[1] != y.shape[1]:
        raise ValueError("Input vectors must have the same dimensionality")
    # (N, 1, D) - (M, D) broadcasts to (N, M, D); reduce over the feature axis.
    differences = mx.expand_dims(x, 1) - y
    # Return the array directly: a `tolist()` round-trip would force
    # evaluation, copy N*M values through Python and coerce the dtype.
    return mx.linalg.norm(differences, axis=-1)
def index_to_mask(index: mx.array, size: Optional[int] = None) -> mx.array:
    """Converts indices to a mask representation.

    Args:
        index: Array of indices where the mask will be True. It is flattened
            before use.
        size: Length of the returned mask array. By default, the size is set
            to the maximum index in the indices + 1 (or 0 when ``index`` is
            empty). NOTE(review): indices >= ``size`` are not validated here.

    Returns:
        A boolean array of length ``size``, with `True` at the given ``index``
        indices and `False` elsewhere.

    Example:

    .. code-block:: python

        >>> index = mx.array([1, 2, 3])
        >>> index_to_mask(index, size=5)
        array([False, True, True, True, False], dtype=bool)
    """
    index = index.reshape(-1)
    has_indices = index.size > 0
    if size is None:
        # `max` of an empty array is undefined; an empty index gives an empty mask.
        size = int(index.max().item()) + 1 if has_indices else 0
    mask = mx.zeros(size, dtype=mx.bool_)
    if has_indices:
        mask[index] = True
    return mask