mlx_graphs.data.batch.GraphDataBatch
class mlx_graphs.data.batch.GraphDataBatch(graphs: list[GraphData], **kwargs)
Concatenates multiple GraphData objects into a single GraphDataBatch for efficient computation and parallelization over multiple graphs.
All graphs remain disconnected in the batch: no two graphs share a node or an edge. GraphDataBatch is especially useful for speeding up graph classification tasks, where many graphs fit in memory at once and can be processed in parallel.
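Keeping graphs disconnected inside one batch is commonly done by shifting each graph's node indices by the number of nodes in the graphs before it, then concatenating the edge indices. The sketch below illustrates that idea in plain Python, assuming this offsetting scheme; `batch_edge_indices` is a hypothetical helper, not the mlx_graphs implementation, and it uses nested lists instead of `mx.array` so it runs without MLX.

```python
# Minimal sketch of disjoint-union batching (hypothetical helper, not mlx_graphs code).


def batch_edge_indices(
    edge_indices: list[list[list[int]]], num_nodes: list[int]
) -> list[list[int]]:
    """Concatenate per-graph [sources, targets] edge indices into one graph.

    Node ids of graph k are shifted by the total node count of graphs 0..k-1,
    so no two graphs share a node id and no edge links two different graphs.
    """
    sources: list[int] = []
    targets: list[int] = []
    offset = 0
    for (src, dst), n in zip(edge_indices, num_nodes):
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        offset += n  # next graph's ids start after this graph's nodes
    return [sources, targets]


edges = [
    [[0, 0, 0], [1, 1, 1]],  # graph 0, 3 nodes
    [[1, 1, 1], [2, 2, 2]],  # graph 1, 3 nodes
]
print(batch_edge_indices(edges, [3, 3]))
# [[0, 0, 0, 4, 4, 4], [1, 1, 1, 5, 5, 5]]
```

The second graph's edge 1 → 2 becomes 4 → 5 because the first graph occupies node ids 0 to 2.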
Example:
    import mlx.core as mx
    from mlx_graphs.data.batch import GraphDataBatch
    from mlx_graphs.data.data import GraphData

    # Three small graphs with 3 nodes and 3 edges each
    graphs = [
        GraphData(
            edge_index=mx.array([[0, 0, 0], [1, 1, 1]]),
            node_features=mx.zeros((3, 1)),
        ),
        GraphData(
            edge_index=mx.array([[1, 1, 1], [2, 2, 2]]),
            node_features=mx.ones((3, 1)),
        ),
        GraphData(
            # node ids must lie in [0, 3) for a 3-node graph
            edge_index=mx.array([[0, 0, 0], [2, 2, 2]]),
            node_features=mx.ones((3, 1)) * 2,
        ),
    ]

    batch = GraphDataBatch(graphs)
    print(batch)
    # GraphDataBatch(
    #     edge_index(shape=[2, 9], int32)
    #     node_features(shape=[9, 1], float32))

    print(batch.num_graphs)
    # 3

    print(batch[1])
    # GraphData(
    #     edge_index(shape=[2, 3], int32)
    #     node_features(shape=[3, 1], float32))

    print(batch[1:])
    # [
    #     GraphData(
    #         edge_index(shape=[2, 3], int32)
    #         node_features(shape=[3, 1], float32)),
    #     GraphData(
    #         edge_index(shape=[2, 3], int32)
    #         node_features(shape=[3, 1], float32))
    # ]
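Indexing such as `batch[1]` has to undo the batching step and return a standalone graph. The sketch below shows one way to do this for the edge index alone, assuming each graph's edges occupy a contiguous block of columns in graph order and that per-graph node and edge counts are known; `unbatch_edge_index` is a hypothetical helper, not the library's code.

```python
# Sketch: recover graph i's edge index from a disjoint-union batch
# (hypothetical helper; assumes contiguous, in-order edge blocks).


def unbatch_edge_index(
    batched: list[list[int]], num_nodes: list[int], num_edges: list[int], i: int
) -> list[list[int]]:
    """Return graph i's [sources, targets] with node ids shifted back to start at 0."""
    node_offset = sum(num_nodes[:i])  # nodes belonging to graphs before i
    start = sum(num_edges[:i])  # first edge column of graph i
    stop = start + num_edges[i]
    src, dst = batched
    return [
        [s - node_offset for s in src[start:stop]],
        [t - node_offset for t in dst[start:stop]],
    ]


batched = [[0, 0, 0, 4, 4, 4], [1, 1, 1, 5, 5, 5]]
print(unbatch_edge_index(batched, num_nodes=[3, 3], num_edges=[3, 3], i=1))
# [[1, 1, 1], [2, 2, 2]]
```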
Methods
__init__(graphs, **kwargs)
Builds a batch from a list of GraphData objects.
has_isolated_nodes()
Returns a boolean indicating whether the graph has isolated nodes (i.e., nodes that are not linked to any other node).
has_self_loops()
Returns a boolean indicating whether the graph contains self-loops.
is_directed()
Returns a boolean indicating whether the graph is directed.
is_undirected()
Returns a boolean indicating whether the graph is undirected.
to_dict()
Converts the data object to a dictionary.
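The structural checks listed under Methods can be illustrated on a plain [sources, targets] edge index. The helpers below are pure-Python sketches with the same names for readability, not the mlx_graphs implementations. Two choices are assumptions: self-loops are ignored when looking for isolated nodes, following the wording "not linked to any other node", and `is_directed` is taken to be the negation of `is_undirected`.

```python
# Illustrative sketches of the structural checks (not mlx_graphs code).


def has_self_loops(edge_index: list[list[int]]) -> bool:
    """True if some edge goes from a node to itself."""
    src, dst = edge_index
    return any(s == t for s, t in zip(src, dst))


def has_isolated_nodes(edge_index: list[list[int]], num_nodes: int) -> bool:
    """True if some node has no edge to a *different* node (self-loops ignored)."""
    src, dst = edge_index
    linked = {n for s, t in zip(src, dst) if s != t for n in (s, t)}
    return any(n not in linked for n in range(num_nodes))


def is_undirected(edge_index: list[list[int]]) -> bool:
    """True if every edge (u, v) has a matching reverse edge (v, u)."""
    edges = set(zip(*edge_index))
    return all((t, s) in edges for s, t in edges)


def is_directed(edge_index: list[list[int]]) -> bool:
    """Assumed here to be the negation of is_undirected."""
    return not is_undirected(edge_index)


triangle = [[0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2]]
print(is_undirected(triangle), has_self_loops(triangle))  # True False
print(has_isolated_nodes(triangle, num_nodes=4))  # True: node 3 has no edges
```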
Attributes
batch_indices
Array giving, for each node, the index of the graph in the batch that the node belongs to.
num_edge_classes
Returns the number of edge classes in the current graph.
num_edge_features
Returns the number of edge features.
num_edges
Number of edges in the batched graph, i.e., summed over all graphs in the batch.
num_graph_classes
Returns the number of graph classes in the current graph.
num_graph_features
Returns the number of graph features.
num_graphs
Number of graphs in the batch.
num_node_classes
Returns the number of node classes in the current graph.
num_node_features
Returns the number of node features.
num_nodes
Number of nodes in the batched graph, i.e., summed over all graphs in the batch.
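`batch_indices` is what lets a model map node-level values back to individual graphs, for example in a graph-level readout for classification. The sketch below assumes the indices are laid out graph by graph in batch order, as the disjoint concatenation suggests. Both `make_batch_indices` and `mean_pool` are hypothetical plain-Python helpers working on lists rather than `mx.array`; mean pooling is one common readout choice, not necessarily what a given model uses.

```python
# Sketch: per-node graph indices and a per-graph mean readout (hypothetical helpers).


def make_batch_indices(num_nodes: list[int]) -> list[int]:
    """Repeat graph index k once for every node of graph k."""
    return [k for k, n in enumerate(num_nodes) for _ in range(n)]


def mean_pool(node_values: list[float], batch_indices: list[int], num_graphs: int) -> list[float]:
    """Average node values per graph; empty graphs get 0.0."""
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, g in zip(node_values, batch_indices):
        sums[g] += value
        counts[g] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


idx = make_batch_indices([3, 3, 3])
print(idx)  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
# Node features of the example above: zeros, ones, and twos
print(mean_pool([0, 0, 0, 1, 1, 1, 2, 2, 2], idx, num_graphs=3))  # [0.0, 1.0, 2.0]
```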