from typing import Literal, Optional, Union
import mlx.core as mx
import numpy as np
from mlx_graphs.utils import has_isolated_nodes, has_self_loops
from mlx_graphs.utils.topology import is_undirected
class GraphData:
    """
    Represents a graph data object with optional features and labels.

    Args:
        edge_index: edge index representing the topology of
            the graph, with shape [2, num_edges].
        node_features: Array of shape [num_nodes, num_node_features]
            containing the features of each node.
        edge_features: Array of shape [num_edges, num_edge_features]
            containing the features of each edge.
        graph_features: Array of shape [num_graph_features]
            containing the global features of the graph.
        node_labels: Array of shape [num_nodes, num_node_labels]
            containing the labels of each node.
        edge_labels: Array of shape [num_edges, num_edge_labels]
            containing the labels of each edge.
        graph_labels: Array of shape [1, num_graph_labels]
            containing the global labels of the graph.
        **kwargs: Additional keyword arguments to store any other custom attributes.
    """

    def __init__(
        self,
        edge_index: mx.array,
        node_features: Optional[mx.array] = None,
        edge_features: Optional[mx.array] = None,
        graph_features: Optional[mx.array] = None,
        node_labels: Optional[mx.array] = None,
        edge_labels: Optional[mx.array] = None,
        graph_labels: Optional[mx.array] = None,
        **kwargs,
    ):
        self.edge_index = edge_index
        self.node_features = node_features
        self.edge_features = edge_features
        self.graph_features = graph_features
        self.node_labels = node_labels
        self.edge_labels = edge_labels
        self.graph_labels = graph_labels
        # Any extra keyword argument becomes a plain instance attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        strings = []
        for k, v in vars(self).items():
            # Unset (None) attributes and private ("_"-prefixed) ones are hidden.
            if v is not None and not k.startswith("_"):
                if isinstance(v, mx.array):
                    strings.append(
                        f"{k}(shape={v.shape}, {str(v.dtype).split('.')[-1]})"
                    )
                else:
                    strings.append(f"{k}={v}")
        prefix = "\n\t"
        return f"{type(self).__name__}({prefix + prefix.join(strings)})"

    def to_dict(self) -> dict:
        """Converts the Data object to a dictionary.

        Returns:
            A dictionary with all attributes of the `GraphData` object
            whose value is not None.
        """
        return {k: v for k, v in self.__dict__.items() if v is not None}

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph.

        Taken from the first dimension of `node_features` when available;
        otherwise it is the number of distinct indices appearing in
        `edge_index` (nodes that appear in no edge are then not counted).
        """
        if self.node_features is not None:
            return self.node_features.shape[0]
        # NOTE: This may be slow for large graphs.
        # `np.asarray` instead of `np.array(..., copy=False)`: with NumPy 2 the
        # latter means "never copy" and raises if a copy turns out to be needed.
        return np.unique(np.asarray(self.edge_index)).size

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph (second dimension of `edge_index`)."""
        return self.edge_index.shape[1]

    @property
    def num_node_classes(self) -> int:
        """Returns the number of node classes in the current graph."""
        return self._num_classes("node")

    @property
    def num_edge_classes(self) -> int:
        """Returns the number of edge classes in the current graph."""
        return self._num_classes("edge")

    @property
    def num_graph_classes(self) -> int:
        """Returns the number of graph classes in the current graph."""
        return self._num_classes("graph")

    @property
    def num_node_features(self) -> int:
        """Returns the number of node features (0 if there are none)."""
        if self.node_features is None:
            return 0
        return 1 if self.node_features.ndim == 1 else self.node_features.shape[-1]

    @property
    def num_edge_features(self) -> int:
        """Returns the number of edge features (0 if there are none)."""
        if self.edge_features is None:
            return 0
        return 1 if self.edge_features.ndim == 1 else self.edge_features.shape[-1]

    @property
    def num_graph_features(self) -> int:
        """Returns the number of graph features (0 if there are none)."""
        if self.graph_features is None:
            return 0
        return 1 if self.graph_features.ndim == 1 else self.graph_features.shape[-1]

    def _num_classes(self, task: Literal["node", "edge", "graph"]) -> int:
        """Counts the classes of the `{task}_labels` attribute.

        Returns 0 when the labels are unset. Labels with one value per
        leading row (shape [N], [N, 1], or a scalar) are treated as class
        indices and the number of distinct values is returned; otherwise the
        size of the last axis is returned.
        """
        labels = getattr(self, f"{task}_labels")
        if labels is None:
            return 0
        # A 0-d array has no shape[0]; treat it like a single class index.
        if labels.ndim == 0 or labels.size == labels.shape[0]:
            return np.unique(np.array(labels)).size
        return labels.shape[-1]

    def __cat_dim__(self, key: str, *args, **kwargs) -> int:
        """This method can be overridden when batching is used with custom attributes.

        Given the name of a custom attribute `key`, returns the dimension where the
        concatenation happens during batching.
        By default, all attribute names containing "index" will be concatenated on
        axis 1, e.g. `edge_index`. Other attributes are concatenated on axis 0,
        e.g. node features.

        Args:
            key: Name of the attribute on which change the default concatenation
                dimension while using batching.

        Returns:
            The dimension where concatenation will happen when batching.
        """
        if "index" in key:
            return 1
        return 0

    def __inc__(self, key: str, *args, **kwargs) -> Union[int, None]:
        """This method can be overridden when batching is used with custom attributes.

        Given the name of a custom attribute `key`, returns an integer indicating the
        incrementing value to apply to the elements of the attribute.
        By default, all attribute names containing "index" will be incremented based on
        the number of nodes in previous batches to avoid duplicate nodes in the index,
        e.g. `edge_index`. Other attributes are not incremented and keep their original
        values, e.g. node features.
        If incrementation is not used, the return value should be set to `None`.

        Args:
            key: Name of the attribute on which change the default incrementation
                behavior while using batching.

        Returns:
            Incrementing value for the given attribute or None.
        """
        if "index" in key:
            return self.num_nodes
        return None

    def has_isolated_nodes(self) -> bool:
        """Returns a boolean of whether the graph has isolated nodes or not
        (i.e., nodes that don't have a link to any other nodes)."""
        # Resolves to the module-level helper, not this method.
        return has_isolated_nodes(self.edge_index, self.num_nodes)

    def has_self_loops(self) -> bool:
        """Returns a boolean of whether the graph contains self loops."""
        return has_self_loops(self.edge_index)

    def is_undirected(self) -> bool:
        """Returns a boolean of whether the graph is undirected or not.

        Edge features, when present, are passed to the check as well.
        """
        return is_undirected(self.edge_index, self.edge_features)

    def is_directed(self) -> bool:
        """Returns a boolean of whether the graph is directed or not."""
        return not self.is_undirected()
class HeteroGraphData:
    """
    Represents a graph structure with multiple node and edge types.

    Args:
        edge_index_dict: A dictionary mapping edge types
            to their corresponding edge indices. The edge indices are
            represented as a 2D array of shape `[2, num_edges]`, where the
            first row contains the source node indices and the second row
            contains the destination node indices.
        node_features_dict: A dictionary
            mapping node types to their corresponding node feature. Each node
            feature has shape `[num_nodes, num_features]`.
        edge_features_dict: A dictionary
            mapping edge types to their corresponding edge feature matrices.
            Each edge feature matrix has shape `[num_edges, num_features]`.
        graph_features: A 1D array containing graph-level
            features.
        node_labels_dict: A dictionary mapping
            node types to their corresponding node label arrays.
            Each node label array has shape `[num_nodes]`.
        edge_labels_dict: A dictionary mapping
            edge types to their corresponding edge label arrays.
            Each edge label array has shape `[num_edges]`.
        edge_labels_index_dict: A dictionary
            mapping edge types to their corresponding edge label index arrays.
            The edge label indices indicate the edges for which labels are
            available.
        graph_labels: A 1D array containing graph-level
            labels.
        **kwargs: Additional keyword arguments to store custom attributes.

    Example:

        .. code-block:: python

            edge_index_dict = {
                ("user", "rates", "movie"): mx.array([[0, 1], [0, 1]]),
                ("movie", "rev_rates", "user"): mx.array([[0, 1], [0, 1]]),
            }
            node_features_dict = {
                "user": mx.array([[0.2], [0.8]]),
                "movie": mx.array([[0.5], [0.3]]),
            }
            data = HeteroGraphData(edge_index_dict, node_features_dict)
    """

    def __init__(
        self,
        edge_index_dict: dict[tuple[str, str, str], mx.array],
        node_features_dict: Optional[dict[str, mx.array]] = None,
        edge_features_dict: Optional[dict[tuple[str, str, str], mx.array]] = None,
        graph_features: Optional[mx.array] = None,
        node_labels_dict: Optional[dict[str, mx.array]] = None,
        edge_labels_dict: Optional[dict[tuple[str, str, str], mx.array]] = None,
        edge_labels_index_dict: Optional[
            dict[tuple[str, str, str], mx.array]
        ] = None,
        graph_labels: Optional[mx.array] = None,
        **kwargs,
    ) -> None:
        self.edge_index_dict = edge_index_dict
        self.node_features_dict = node_features_dict
        self.edge_features_dict = edge_features_dict
        self.graph_features = graph_features
        self.node_labels_dict = node_labels_dict
        self.edge_labels_index_dict = edge_labels_index_dict
        self.edge_labels_dict = edge_labels_dict
        self.graph_labels = graph_labels
        # Any extra keyword argument becomes a plain instance attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self) -> str:
        def describe(value) -> str:
            # e.g. "(shape=(2, 3), float32)"
            return f"(shape={value.shape}, {str(value.dtype).split('.')[-1]})"

        lines = [f"{type(self).__name__}("]

        # Node features
        if self.node_features_dict is not None:
            lines.append("node_features(")
            lines.extend(
                f"  '{node_type}': {describe(value)},"
                for node_type, value in self.node_features_dict.items()
            )
            lines.append(")")

        # Edge index
        lines.append("edge_index(")
        lines.extend(
            f"  '{src} -> {rel} -> {dst}': {describe(value)},"
            for (src, rel, dst), value in self.edge_index_dict.items()
        )
        lines.append(")")

        # Edge label indices and edge labels, both keyed by edge-type triples.
        for title, typed_dict in (
            ("edge_labels_index_dict", self.edge_labels_index_dict),
            ("edge_labels_dict", self.edge_labels_dict),
        ):
            lines.append(f"{title}(")
            if typed_dict is not None:
                lines.extend(
                    f"  ('{src}', '{rel}', '{dst}'): {describe(value)},"
                    for (src, rel, dst), value in typed_dict.items()
                )
            lines.append(")")

        # Additional attributes (everything not printed above).
        already_shown = {
            "node_features_dict",
            "edge_index_dict",
            "edge_labels_index_dict",
            "edge_labels_dict",
        }
        for k, v in vars(self).items():
            if v is None or k.startswith("_") or k in already_shown:
                continue
            if isinstance(v, dict):
                lines.append(f"{k}(")
                for key, value in v.items():
                    if isinstance(value, mx.array):
                        lines.append(f"  '{key}': {describe(value)},")
                    else:
                        lines.append(f"  '{key}': {value},")
                lines.append(")")
            elif isinstance(v, mx.array):
                lines.append(f"{k}{describe(v)}")
            else:
                lines.append(f"{k}={v}")
        lines.append(")")
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Converts the Data object to a dictionary.

        Returns:
            A dictionary with all attributes of the `HeteroGraphData` object
            whose value is not None.
        """
        return {k: v for k, v in self.__dict__.items() if v is not None}

    def _edge_node_indices(self) -> dict[str, list[np.ndarray]]:
        """Collects, per node type, every edge_index row referring to that type.

        Source rows (`edge_index[0]`) are filed under the relation's source
        node type and destination rows (`edge_index[1]`) under its
        destination node type.
        """
        indices: dict[str, list[np.ndarray]] = {}
        for (src_type, _, dst_type), edge_index in self.edge_index_dict.items():
            # `np.asarray` instead of `np.array(..., copy=False)`: with NumPy 2
            # the latter raises if a copy turns out to be needed.
            edge_index_np = np.asarray(edge_index)
            indices.setdefault(src_type, []).append(edge_index_np[0])
            indices.setdefault(dst_type, []).append(edge_index_np[1])
        return indices

    @property
    def num_nodes(self) -> dict[str, int]:
        """Number of nodes for each node type.

        Taken from the first dimension of each entry of `node_features_dict`
        when available; otherwise, for each node type, it is the number of
        distinct indices of that type across all edge types (nodes that appear
        in no edge are then not counted).
        """
        if self.node_features_dict is not None:
            return {
                node_type: node_features.shape[0]
                for node_type, node_features in self.node_features_dict.items()
            }
        # NOTE: This may be slow for large graphs.
        return {
            node_type: np.unique(np.concatenate(parts)).size
            for node_type, parts in self._edge_node_indices().items()
        }

    @property
    def num_graph_features(self) -> int:
        """Returns the number of graph features (0 if there are none)."""
        if self.graph_features is None:
            return 0
        return 1 if self.graph_features.ndim == 1 else self.graph_features.shape[-1]

    @property
    def num_edges(self) -> dict[tuple[str, str, str], int]:
        """Dictionary of the number of edges for each edge type in the graph."""
        return {
            edge_type: edge_index.shape[1]
            for edge_type, edge_index in self.edge_index_dict.items()
        }

    @property
    def num_node_classes(self) -> Union[dict[str, int], None]:
        """
        Returns a dictionary of the number of node classes
        for each node type in the current graph, or None when
        `node_features_dict` is not set.
        """
        if self.node_features_dict is not None:
            return {
                node_type: self._num_classes("node", node_type)
                for node_type in self.node_features_dict.keys()
            }
        return None

    @property
    def num_edge_classes(self) -> dict[tuple[str, str, str], int]:
        """
        Returns a dictionary of the number of edge classes
        for each edge type in the current graph.
        """
        return {
            edge_type: self._num_classes("edge", edge_type)
            for edge_type in self.edge_index_dict.keys()
        }

    @property
    def num_node_features(self) -> dict[str, int]:
        """Returns a dictionary of the number of node features for each node type."""
        if self.node_features_dict is None:
            return {}
        return {
            node_type: 1 if node_features.ndim == 1 else node_features.shape[-1]
            for node_type, node_features in self.node_features_dict.items()
        }

    @property
    def num_edge_features(self) -> dict[tuple[str, str, str], int]:
        """Returns a dictionary of the number of edge features for each edge type."""
        if self.edge_features_dict is None:
            return {}
        return {
            edge_type: 1 if edge_features.ndim == 1 else edge_features.shape[-1]
            for edge_type, edge_features in self.edge_features_dict.items()
        }

    def _num_classes(
        self,
        task: Literal["node", "edge"],
        type_key: Union[str, tuple[str, str, str]],
    ) -> int:
        """Counts the classes of `{task}_labels_dict[type_key]`.

        Returns 0 when there are no labels for `type_key`. Labels with one
        value per leading row (shape [N], [N, 1], or a scalar) are treated as
        class indices and the number of distinct values is returned; otherwise
        the size of the last axis is returned.
        """
        labels_dict = getattr(self, f"{task}_labels_dict")
        if labels_dict is None or type_key not in labels_dict:
            return 0
        labels = labels_dict[type_key]
        # A 0-d array has no shape[0]; treat it like a single class index.
        if labels.ndim == 0 or labels.size == labels.shape[0]:
            return np.unique(np.array(labels)).size
        return labels.shape[-1]

    def has_isolated_nodes(self, node_type: str) -> bool:
        """Returns a boolean of whether the graph has isolated nodes of the
        given type (i.e., nodes of `node_type` that appear neither as source
        nor as destination in any edge type).

        Args:
            node_type: The node type to inspect.

        Raises:
            KeyError: If `node_type` is not a key of `num_nodes`.
        """
        num_nodes = self.num_nodes[node_type]
        parts = self._edge_node_indices().get(node_type)
        if not parts:
            # No edge type touches this node type: every one of its nodes is
            # isolated.
            return num_nodes > 0
        connected = np.concatenate(parts)
        # Drop out-of-range indices so they cannot mask a truly isolated node.
        connected = connected[(connected >= 0) & (connected < num_nodes)]
        return np.unique(connected).size < num_nodes

    def has_self_loops(self) -> bool:
        """
        Returns a boolean of whether the graph contains at least one self loop.

        Only edge types whose source and destination node types are the same
        are considered: in a relation between two different node types, equal
        source and destination indices refer to two different nodes.
        """
        # `has_self_loops` below resolves to the module-level helper.
        return any(
            has_self_loops(edge_index)
            for (src_type, _, dst_type), edge_index in self.edge_index_dict.items()
            if src_type == dst_type
        )

    def _edge_type_is_undirected(self, edge_type: tuple[str, str, str]) -> bool:
        """Undirectedness check for a single edge type.

        Uses the edge type's edge features, when present, in the check.
        Shared by `is_undirected` and `is_directed` so both use the same
        definition.
        """
        edge_features = (
            self.edge_features_dict.get(edge_type, None)
            if self.edge_features_dict
            else None
        )
        # `is_undirected` here resolves to the module-level helper.
        return is_undirected(self.edge_index_dict[edge_type], edge_features)

    def is_undirected(self) -> bool:
        """
        Returns a boolean indicating whether all edge types in the graph are undirected.
        """
        return all(
            self._edge_type_is_undirected(edge_type)
            for edge_type in self.edge_index_dict
        )

    def is_directed(self) -> bool:
        """
        Returns a boolean of whether all edge types are directed
        (i.e. none of them is undirected).
        """
        return all(
            not self._edge_type_is_undirected(edge_type)
            for edge_type in self.edge_index_dict
        )