Source code for mlx_graphs.data.utils

import functools

from .data import GraphData


[docs] def validate_list_of_graph_data(func): """Decorator function to check the validity of a list of `GraphData`.""" @functools.wraps(func) def wrapper(graph_list: list[GraphData], *args, **kwargs): if not isinstance(graph_list, list): raise ValueError(f"Expected list of GraphData, got {type(graph_list)}.") try: expected_attr = set(graph_list[0].to_dict()) except AttributeError: raise ValueError( "Expected list of GraphData. " "Graph at index 0 in the batch is not of type `GraphData`." ) for i, graph in enumerate(graph_list): if not isinstance(graph, GraphData): raise ValueError( "Expected list of GraphData. " f"Graph at index {i} in the batch is not of type `GraphData`." ) graph_attr = set(graph.to_dict()) if graph_attr != expected_attr: raise ValueError( "A graph in the batch has mismatching attributes. " f"Found attributes at graph index {i}: {graph_attr}." ) return func(graph_list, *args, **kwargs) return wrapper