mlx_graphs.data.batch.GraphDataBatch
- class mlx_graphs.data.batch.GraphDataBatch(graphs: list[GraphData], **kwargs)
Concatenates multiple GraphData objects into a single GraphDataBatch, enabling efficient, parallel computation over many graphs at once.
All graphs remain disconnected in the batch: no two graphs share any nodes or edges. GraphDataBatch is especially useful for speeding up graph classification tasks, where many small graphs fit easily into memory and can be processed in parallel.
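The disjoint-union idea behind batching can be sketched in plain Python. This is a hypothetical illustration, not the mlx_graphs implementation: `batch_edge_indices` is an invented helper, edge indices are plain lists rather than `mx.array`s, and it assumes, as is conventional for this kind of batching, that each graph's node ids are shifted by the total node count of the graphs before it.

```python
# Hypothetical sketch of disjoint-union batching on plain Python lists.

def batch_edge_indices(edge_indices, num_nodes):
    """Merge per-graph edge lists into one edge list for the batch.

    edge_indices: list of (sources, targets) pairs, one per graph.
    num_nodes: number of nodes in each graph, in the same order.
    """
    sources, targets = [], []
    offset = 0
    for (src, dst), n in zip(edge_indices, num_nodes):
        # Shift node ids by the node count of all preceding graphs, so the
        # graphs can never share a node and therefore stay disconnected.
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        offset += n
    return sources, targets


# Graph A: 3 nodes, edges 0->1, 1->2. Graph B: 2 nodes, edge 0->1.
src, dst = batch_edge_indices([([0, 1], [1, 2]), ([0], [1])], [3, 2])
print(src, dst)  # [0, 1, 3] [1, 2, 4]
```

Under this scheme, node feature rows would simply be concatenated in the same graph order, so shifted node ids and feature rows stay aligned.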
Example:
```python
import mlx.core as mx
from mlx_graphs.data.data import GraphData
from mlx_graphs.data.batch import GraphDataBatch

# Three small graphs with 3 nodes each (valid node ids: 0, 1, 2).
graphs = [
    GraphData(
        edge_index=mx.array([[0, 0, 0], [1, 1, 1]]),
        node_features=mx.zeros((3, 1)),
    ),
    GraphData(
        edge_index=mx.array([[1, 1, 1], [2, 2, 2]]),
        node_features=mx.ones((3, 1)),
    ),
    GraphData(
        edge_index=mx.array([[0, 0, 0], [2, 2, 2]]),
        node_features=mx.ones((3, 1)) * 2,
    ),
]

batch = GraphDataBatch(graphs)
print(batch)
# GraphDataBatch(
#     edge_index(shape=[2, 9], int32)
#     node_features(shape=[9, 1], float32))

print(batch.num_graphs)
# 3

# Indexing returns the graph at that position as a GraphData.
print(batch[1])
# GraphData(
#     edge_index(shape=[2, 3], int32)
#     node_features(shape=[3, 1], float32))

# Slicing returns a list of GraphData.
print(batch[1:])
# [
#     GraphData(
#         edge_index(shape=[2, 3], int32)
#         node_features(shape=[3, 1], float32)),
#     GraphData(
#         edge_index(shape=[2, 3], int32)
#         node_features(shape=[3, 1], float32))
# ]
```
Methods
- __init__(graphs, **kwargs)
- has_isolated_nodes(): Returns whether the graph has isolated nodes, i.e., nodes with no edge to any other node.
- has_self_loops(): Returns whether the graph contains self loops.
- is_directed(): Returns whether the graph is directed.
- is_undirected(): Returns whether the graph is undirected.
- to_dict(): Converts the GraphData object to a dictionary.
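As a rough illustration of what the structural checks compute, here is a hypothetical plain-Python sketch over an edge list. The function names mirror the methods above but are stand-ins, not the mlx_graphs implementations (which operate on `mx.array` edge indices). Treating a node whose only edge is a self loop as isolated is an assumption based on the wording "to any other nodes".

```python
# Hypothetical plain-Python sketches of the structural checks listed above.
# Edges are given as two parallel lists: src[i] -> dst[i].

def has_self_loops(src, dst):
    # A self loop is an edge whose source and target are the same node.
    return any(s == t for s, t in zip(src, dst))


def has_isolated_nodes(src, dst, num_nodes):
    # Assumption: a node is isolated if no edge links it to a *different*
    # node, so self loops are ignored here.
    linked = set()
    for s, t in zip(src, dst):
        if s != t:
            linked.update((s, t))
    return any(n not in linked for n in range(num_nodes))


def is_undirected(src, dst):
    # Undirected means every edge (s, t) is matched by its reverse (t, s).
    edges = set(zip(src, dst))
    return all((t, s) in edges for s, t in edges)


def is_directed(src, dst):
    return not is_undirected(src, dst)


src, dst = [0, 1, 2], [1, 0, 2]
print(has_self_loops(src, dst))         # True: edge 2 -> 2
print(has_isolated_nodes(src, dst, 3))  # True: node 2 only links to itself
print(is_undirected(src, dst))          # True: 0 -> 1 and 1 -> 0 both present
```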
Attributes
- batch_indices: Array giving, for each node, the index of the graph it belongs to within the batch.
- num_edge_classes: Returns the number of edge classes in the current graph.
- num_edge_features: Returns the number of edge features.
- num_edges: Number of edges in the graph.
- num_graph_classes: Returns the number of graph classes in the current graph.
- num_graph_features: Returns the number of graph features.
- num_graphs: Number of graphs in the batch.
- num_node_classes: Returns the number of node classes in the current graph.
- num_node_features: Returns the number of node features.
- num_nodes: Number of nodes in the graph.
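The batch_indices attribute is what lets a model reduce per-node outputs to one value per graph, for example in graph classification. Below is a hypothetical plain-Python sketch, assuming batch_indices holds one graph id per node with each graph's nodes stored contiguously, as in the example above where three 3-node graphs become 9 nodes. `make_batch_indices` and `mean_pool` are illustrative names, not mlx_graphs functions.

```python
# Hypothetical sketch: per-node graph ids and a per-graph mean readout.

def make_batch_indices(num_nodes_per_graph):
    # Node j of the batch gets the id of the graph it came from,
    # e.g. [3, 2] -> [0, 0, 0, 1, 1].
    return [g for g, n in enumerate(num_nodes_per_graph) for _ in range(n)]


def mean_pool(node_values, batch_indices, num_graphs):
    # Average scalar node values separately for each graph.
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, g in zip(node_values, batch_indices):
        sums[g] += value
        counts[g] += 1
    # Graphs with no nodes get 0.0 instead of dividing by zero.
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


# Three 3-node graphs with constant features 0, 1 and 2.
batch_indices = make_batch_indices([3, 3, 3])
features = [0.0] * 3 + [1.0] * 3 + [2.0] * 3
print(batch_indices)                          # [0, 0, 0, 1, 1, 1, 2, 2, 2]
print(mean_pool(features, batch_indices, 3))  # [0.0, 1.0, 2.0]
```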