mlx_graphs.data.batch.GraphDataBatch


class mlx_graphs.data.batch.GraphDataBatch(graphs: list[GraphData], **kwargs)

Concatenates multiple GraphData objects into a single GraphDataBatch, enabling efficient, parallel computation over multiple graphs.

All graphs remain disconnected within the batch: no two graphs share any nodes or edges. GraphDataBatch is especially useful for speeding up graph classification, where many graphs fit into memory together and can be processed in parallel.
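
To make the disjoint layout concrete, here is a minimal plain-Python sketch of batching by disjoint union. It is an illustration under assumptions, not the library's implementation: graphs are plain dicts, `batch_graphs` is a hypothetical helper, and each graph's node count is taken from its number of feature rows. The node ids of every graph are shifted by the number of nodes placed before it, so edges from different graphs can never meet.

```python
def batch_graphs(graphs):
    """Hypothetical helper: merge graphs into one disjoint graph (plain Python).

    Each graph is a dict with:
      "edge_index":    [sources, targets], two lists of node ids
      "node_features": list of per-node feature rows
    """
    sources, targets, features = [], [], []
    offset = 0  # number of nodes already placed in the batch
    for g in graphs:
        src, dst = g["edge_index"]
        # Shift node ids so this graph's nodes don't collide with earlier graphs.
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        features.extend(g["node_features"])
        # Assumption: node count = number of feature rows (isolated nodes included).
        offset += len(g["node_features"])
    return {"edge_index": [sources, targets], "node_features": features}


graphs = [
    {"edge_index": [[0, 0, 0], [1, 1, 1]], "node_features": [[0.0], [0.0], [0.0]]},
    {"edge_index": [[1, 1, 1], [2, 2, 2]], "node_features": [[1.0], [1.0], [1.0]]},
    {"edge_index": [[2, 2, 2], [0, 0, 0]], "node_features": [[2.0], [2.0], [2.0]]},
]
batch = batch_graphs(graphs)
print(batch["edge_index"])
# [[0, 0, 0, 4, 4, 4, 8, 8, 8], [1, 1, 1, 5, 5, 5, 6, 6, 6]]
```

Concatenating along the node and edge axes keeps a single flat edge list, which is why one message-passing pass can process every graph in the batch at once.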

Parameters:

graphs (list[GraphData]) – List of GraphData objects to batch together

Example:

import mlx.core as mx
from mlx_graphs.data.data import GraphData
from mlx_graphs.data.batch import GraphDataBatch

graphs = [
    GraphData(
        edge_index=mx.array([[0, 0, 0], [1, 1, 1]]),
        node_features=mx.zeros((3, 1)),
    ),
    GraphData(
        edge_index=mx.array([[1, 1, 1], [2, 2, 2]]),
        node_features=mx.ones((3, 1)),
    ),
    GraphData(
        # Node ids must be valid for a 3-node graph (0, 1 or 2).
        edge_index=mx.array([[2, 2, 2], [0, 0, 0]]),
        node_features=mx.ones((3, 1)) * 2,
    ),
]

batch = GraphDataBatch(graphs)
print(batch)
# GraphDataBatch(
#     edge_index(shape=[2, 9], int32)
#     node_features(shape=[9, 1], float32))

print(batch.num_graphs)
# 3

# Indexing with an integer returns a single GraphData.
print(batch[1])
# GraphData(
#     edge_index(shape=[2, 3], int32)
#     node_features(shape=[3, 1], float32))

# Slicing returns a list of GraphData: here two graphs, each with
# edge_index(shape=[2, 3], int32) and node_features(shape=[3, 1], float32).
print(batch[1:])
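
Retrieving one graph from a batch, as `batch[1]` does above, amounts to slicing out that graph's nodes and edges and shifting its node ids back so they start at 0. The sketch below assumes, for illustration only, that edges are stored contiguously in batch order and that the per-graph node and edge counts are known; `unbatch_graph` is a hypothetical helper, not part of mlx_graphs.

```python
def unbatch_graph(edge_index, node_features, node_counts, edge_counts, i):
    """Hypothetical helper: recover graph i from a disjoint-union batch.

    Assumes graph i's edges occupy a contiguous block of the batched
    edge list, in the same order as the graphs were batched.
    """
    node_start = sum(node_counts[:i])  # nodes belonging to graphs 0..i-1
    edge_start = sum(edge_counts[:i])  # edges belonging to graphs 0..i-1
    src, dst = edge_index
    edge_slice = slice(edge_start, edge_start + edge_counts[i])
    # Undo the node-id offset so the graph's ids start at 0 again.
    local_edge_index = [
        [s - node_start for s in src[edge_slice]],
        [t - node_start for t in dst[edge_slice]],
    ]
    local_features = node_features[node_start:node_start + node_counts[i]]
    return {"edge_index": local_edge_index, "node_features": local_features}


batched_edge_index = [[0, 0, 0, 4, 4, 4, 8, 8, 8], [1, 1, 1, 5, 5, 5, 6, 6, 6]]
batched_features = [[0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [2.0], [2.0], [2.0]]
graph_1 = unbatch_graph(batched_edge_index, batched_features, [3, 3, 3], [3, 3, 3], 1)
print(graph_1["edge_index"])  # [[1, 1, 1], [2, 2, 2]]
```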

__init__(graphs: list[GraphData], **kwargs)

Methods

__init__(graphs, **kwargs)

has_isolated_nodes()

Returns whether the graph has isolated nodes, i.e., nodes that are not linked to any other node.
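
A plain-Python sketch of this check, following the wording above: a node counts as isolated if no edge links it to a different node, so a node whose only edge is a self-loop is treated as isolated. This reading of self-loops and the helper signature are assumptions; the library's own handling may differ.

```python
def has_isolated_nodes(edge_index, num_nodes):
    """Hypothetical helper: True if some node in 0..num_nodes-1 has no edge
    to a different node."""
    src, dst = edge_index
    connected = set()
    for s, t in zip(src, dst):
        if s != t:  # assumption: a self-loop does not link a node to another node
            connected.add(s)
            connected.add(t)
    return any(n not in connected for n in range(num_nodes))


print(has_isolated_nodes([[0, 0, 0], [1, 1, 1]], num_nodes=3))  # True: node 2 has no edges
```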

has_self_loops()

Returns whether the graph contains self-loops.
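
The self-loop check reduces to looking for an edge whose source equals its target. A minimal sketch over a plain `[sources, targets]` edge list, using a hypothetical helper rather than the library's method:

```python
def has_self_loops(edge_index):
    """Hypothetical helper: True if any edge connects a node to itself."""
    src, dst = edge_index
    return any(s == t for s, t in zip(src, dst))


print(has_self_loops([[0, 1], [1, 1]]))  # True: edge (1, 1) is a self-loop
```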

is_directed()

Returns whether the graph is directed.

is_undirected()

Returns whether the graph is undirected.
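
One common way to test undirectedness is to check that every edge (u, v) has a matching reverse edge (v, u); a graph is then directed exactly when this check fails. The sketch below uses that definition on a plain edge list and ignores duplicate edges. Both the definition and the helper names are assumptions for illustration, not the library's code.

```python
def is_undirected(edge_index):
    """Hypothetical helper: True if every edge (u, v) has a reverse (v, u).

    Duplicate edges are collapsed by the set, so multiplicity is ignored.
    """
    src, dst = edge_index
    edges = set(zip(src, dst))
    return all((t, s) in edges for s, t in edges)


def is_directed(edge_index):
    """Hypothetical helper: a graph is directed when it is not undirected."""
    return not is_undirected(edge_index)


print(is_undirected([[0, 1], [1, 0]]))          # True
print(is_directed([[0, 0, 0], [1, 1, 1]]))      # True: no edge (1, 0)
```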

to_dict()

Converts the object to a dictionary.

Attributes

batch_indices

Array giving, for each node, the index of the graph it belongs to within the batch.
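
Per-node graph indices are what let per-node results be grouped back by graph, for example to average node values into one value per graph for graph classification. A plain-Python sketch, assuming scalar node values; `make_batch_indices` and `global_mean_pool` are hypothetical helpers written for this illustration.

```python
def make_batch_indices(node_counts):
    """Hypothetical helper: per-node graph index, e.g. [3, 2] -> [0, 0, 0, 1, 1]."""
    return [g for g, n in enumerate(node_counts) for _ in range(n)]


def global_mean_pool(node_values, batch_indices, num_graphs):
    """Hypothetical helper: average scalar node values per graph."""
    sums = [0.0] * num_graphs
    counts = [0] * num_graphs
    for value, g in zip(node_values, batch_indices):
        sums[g] += value
        counts[g] += 1
    # Graphs with no nodes get 0.0 instead of dividing by zero.
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


batch_indices = make_batch_indices([3, 3, 3])
values = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
print(global_mean_pool(values, batch_indices, num_graphs=3))  # [0.0, 1.0, 2.0]
```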

num_edge_classes

Returns the number of edge classes in the current graph.

num_edge_features

Returns the number of edge features.

num_edges

Total number of edges in the batch, summed over all graphs.

num_graph_classes

Returns the number of graph classes.

num_graph_features

Returns the number of graph features.

num_graphs

Number of graphs in the batch.

num_node_classes

Returns the number of node classes in the current graph.

num_node_features

Returns the number of node features.

num_nodes

Total number of nodes in the batch, summed over all graphs.