Source code for mlx_graphs.utils.validators
import functools
import mlx.core as mx
[docs]
def validate_adjacency_matrix(func):
"""Decorator function to check the validity of an adjacency matrix."""
@functools.wraps(func)
def wrapper(adjacency_matrix, *args, **kwargs):
if adjacency_matrix.ndim != 2:
raise ValueError(
"Adjacency matrix must be two-dimensional",
f"(got {adjacency_matrix.ndim} dimensions)",
)
if not mx.equal(*adjacency_matrix.shape):
raise ValueError(
"Adjacency matrix must be a square matrix",
f"(got {adjacency_matrix.shape} shape)",
)
return func(adjacency_matrix, *args, **kwargs)
return wrapper
[docs]
def validate_edge_index(func):
"""Decorator function to check the validity of an edge_index."""
@functools.wraps(func)
def wrapper(edge_index, *args, **kwargs):
if edge_index.ndim != 2:
raise ValueError(
"edge_index must be 2-dimensional with shape [2, num_edges]",
f"(got {edge_index.ndim} dimensions)",
)
if edge_index.shape[0] != 2:
raise ValueError(
"edge_index must be 2-dimensional with shape [2, num_edges]",
f"(got {edge_index.shape} shape)",
)
return func(edge_index, *args, **kwargs)
return wrapper
[docs]
def validate_edge_index_and_features(func):
"""Decorator function to check the validity of an edge_index and edge_features."""
@functools.wraps(func)
@validate_edge_index
def wrapper(edge_index, edge_features=None, *args, **kwargs):
if edge_features is not None:
if edge_index.shape[1] != edge_features.shape[0]:
raise ValueError(
"edge_features must be 1 per edge ",
f"(got {edge_index.shape[1]} edges",
f"and {edge_features.shape[0]} features)",
)
return func(edge_index, edge_features, *args, **kwargs)
return wrapper