Source code for mlx_graphs.data.collate
import mlx.core as mx
import numpy as np
from mlx_graphs.data.data import GraphData
from mlx_graphs.data.utils import validate_list_of_graph_data
@validate_list_of_graph_data
def collate(graph_list: list[GraphData]) -> dict:
"""Concatenates attributes of multiple graphs based on the specifications
of each `GraphData`.
By default, concatenates all default array attributes in dim 0
apart from `edge_index` which is concatenated along dim 1.
Each graph remains independent in the final graph by incrementing
the indices in `edge_index` based on the cumsum of previous number
of nodes per graph.
Args:
graph_list: List of `GraphData` objects to collate
Returns:
Dict containing all the attributes of the unified and disconnected big graph as
well as the "private" attributes used behind the hood by the batching.
These private attributes start with an underscore "_" and can be ignore by
the user.
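
    Example:
        A minimal usage sketch (illustrative; it assumes graphs built with an
        `edge_index` of shape [2, num_edges] and `node_features`)::

            g1 = GraphData(
                edge_index=mx.array([[0, 1], [1, 0]]),
                node_features=mx.zeros((2, 8)),
            )
            g2 = GraphData(
                edge_index=mx.array([[0, 1, 2], [1, 2, 0]]),
                node_features=mx.ones((3, 8)),
            )
            batch_dict = collate([g1, g2])
            # batch_dict["edge_index"] contains the edges of both graphs,
            # with g2's node indices shifted by the number of nodes in g1.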
"""
    batch_attr_dict = {}

    # Pre-compute __inc__ and __cat_dim__ for all attributes outside the loop
    attrs_inc_cat_dim = [
        (attr, graph_list[0].__inc__(key=attr), graph_list[0].__cat_dim__(key=attr))
        for attr in graph_list[0].to_dict()
    ]
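    # For each attribute, `__inc__(key)` gives the per-graph increment applied to
    # index-like values (e.g. the node count for `edge_index`; falsy when no
    # offset is needed) and `__cat_dim__(key)` gives the concatenation axis.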

    # To store pre-computed cumsum for attributes where __inc__ is True
    cumsum_dict = {}

    for attr, __inc__, __cat_dim__ in attrs_inc_cat_dim:
        attr_list, attr_sizes = [], []

        # Compute cumsum outside the inner loop if __inc__ is True
        if __inc__ and attr not in cumsum_dict:
            num_attr_list = [
                getattr(graph, "__inc__")(attr) for graph in graph_list
            ]  # __inc__ returns the per-graph increment (e.g. the node count for edge_index)
            cumsum_dict[attr] = mx.cumsum(mx.array([0] + num_attr_list))

        cumsum = cumsum_dict.get(attr)

        for i, graph in enumerate(graph_list):
            attr_array = getattr(graph, attr)

            if __inc__:
                attr_array = attr_array + cumsum[i]  # type: ignore

            attr_list.append(attr_array)
            attr_sizes.append(attr_array.shape[__cat_dim__])

        # Concatenate all attributes at once outside the loop
        batch_attr_dict[attr] = mx.concatenate(attr_list, axis=__cat_dim__)
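        # Private metadata (prefixed with "_"): per-graph sizes, concatenation
        # axis and, where relevant, the increment offsets used under the hood
        # by the batching.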
batch_attr_dict[f"_size_{attr}"] = mx.array(attr_sizes)
batch_attr_dict[f"_cat_dim_{attr}"] = __cat_dim__
if __inc__:
batch_attr_dict[f"_inc_{attr}"] = True
batch_attr_dict[f"_cumsum_{attr}"] = cumsum

        # Special handling for "edge_index": build the per-node batch assignment
        # vector with vectorized numpy ops
        if attr == "edge_index":
            cumsum = np.array(cumsum)
            batch_indices = np.hstack(
                [
                    np.full((cumsum[i + 1] - cumsum[i]).item(), i)
                    for i in range(len(cumsum) - 1)
                ]
            )
            batch_attr_dict["_batch_indices"] = mx.array(batch_indices)

    return batch_attr_dict
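
# --- Illustrative sketch (not part of the library source) ---
# The snippet below shows how the cumulative node counts computed in `collate`
# become the per-node `_batch_indices` vector. The `node_counts` values are
# made up for illustration; only the `mx`/`np` calls already used above are
# relied upon.
if __name__ == "__main__":
    node_counts = [2, 3, 4]  # hypothetical number of nodes in three graphs
    cumsum = np.array(mx.cumsum(mx.array([0] + node_counts)))  # [0, 2, 5, 9]
    batch_indices = np.hstack(
        [
            np.full((cumsum[i + 1] - cumsum[i]).item(), i)
            for i in range(len(cumsum) - 1)
        ]
    )
    # Each node is mapped to the index of the graph it belongs to:
    # [0, 0, 1, 1, 1, 2, 2, 2, 2]
    print(batch_indices)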