# Source code for mlx_graphs.utils.sorting
import mlx.core as mx
from mlx_graphs.utils.validators import (
validate_edge_index,
validate_edge_index_and_features,
)
[docs]
@validate_edge_index
def sort_edge_index(edge_index: mx.array) -> tuple[mx.array, mx.array]:
    """Reorder the columns of an edge index by source index.

    The ordering is built in two passes: edges are first ranked by their
    target index (row 1), and that ranking is then re-ranked by source
    index (row 0). The two rankings are composed into a single permutation
    that is applied to the input once.

    NOTE(review): edges sharing the same source index end up ordered by
    target index only if ``mx.argsort`` keeps tied elements in their
    incoming order (i.e. is stable) -- confirm against the MLX docs.

    Args:
        edge_index: A [2, num_edges] array representing edge indices,
            where the first row contains source indices and the second
            row contains target indices.

    Returns:
        A tuple ``(sorted_edge_index, sorting_indices)`` where
        ``sorted_edge_index`` equals ``edge_index[:, sorting_indices]``.
    """
    # Pass 1: rank edges by target index.
    by_target = mx.argsort(edge_index[1])

    # Pass 2: rank the target-ordered edges by their source index.
    sources_in_target_order = edge_index[0][by_target]
    by_source = mx.argsort(sources_in_target_order)

    # Compose both rankings so the permutation refers to the original columns.
    permutation = by_target[by_source]
    return edge_index[:, permutation], permutation
[docs]
@validate_edge_index_and_features
def sort_edge_index_and_features(
    edge_index: mx.array, edge_features: mx.array
) -> tuple[mx.array, mx.array]:
    """Sort an edge index and reorder its edge features to match.

    The edge index is sorted with ``sort_edge_index``; the permutation it
    returns is then used to index ``edge_features`` along its first axis,
    so each feature row stays attached to its edge.

    Args:
        edge_index: A [2, num_edges] array representing edge indices,
            where the first row contains source indices and the second
            row contains target indices.
        edge_features: An array representing edge features, where each
            row corresponds to an edge.

    Returns:
        A tuple containing the sorted edge index and the edge features
        rearranged with the same permutation.
    """
    ordered_index, permutation = sort_edge_index(edge_index)
    return ordered_index, edge_features[permutation]