Source code for mlx_graphs.utils.topology
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.utils.sorting import sort_edge_index, sort_edge_index_and_features
from mlx_graphs.utils.validators import validate_edge_index_and_features
[docs]
@validate_edge_index_and_features
def is_undirected(
    edge_index: mx.array, edge_features: Optional[mx.array] = None
) -> bool:
    """
    Determines whether a graph is undirected based on the given edge index
    and optional edge features.

    The graph counts as undirected when sorting its edges ``(u, v)`` gives
    the same edge index as sorting the reversed edges ``(v, u)``. When edge
    features are supplied, the features (permuted along with their edges by
    the sort) must match as well.

    Args:
        edge_index: The edge index of the graph.
        edge_features: Edge features associated with each edge. If provided,
            the function considers both edge indices and features for the
            check.

    Returns:
        True if the graph is undirected, False otherwise.
    """
    # Swap the source and destination rows so that every edge (u, v)
    # becomes (v, u).
    flipped_index = mx.stack([edge_index[1], edge_index[0]])

    if edge_features is None:
        # Index-only check: the two sorted edge lists must coincide.
        forward_index, _ = sort_edge_index(edge_index)
        backward_index, _ = sort_edge_index(flipped_index)
        return bool(mx.array_equal(forward_index, backward_index))

    # Each feature row travels with its edge through the sort, so after
    # sorting, the features at a given (u, v) position come from the
    # original (u, v) edge in one case and from the original (v, u) edge in
    # the other.
    # NOTE(review): with duplicate (u, v) edges carrying different features,
    # the outcome may depend on how sort_edge_index_and_features orders ties
    # — confirm.
    forward_index, forward_feat = sort_edge_index_and_features(
        edge_index, edge_features
    )
    backward_index, backward_feat = sort_edge_index_and_features(
        flipped_index, edge_features
    )

    # Compare features only once the edge structure is known to match.
    if not mx.array_equal(forward_index, backward_index):
        return False
    return bool(mx.array_equal(forward_feat, backward_feat))
[docs]
@validate_edge_index_and_features
def is_directed(edge_index: mx.array, edge_features: Optional[mx.array] = None) -> bool:
    """
    Determines whether a graph is directed based on the given edge index
    and optional edge features.

    This is the logical negation of ``is_undirected`` called with the same
    arguments.

    Args:
        edge_index: The edge index of the graph.
        edge_features: Edge features associated with each edge. If provided,
            the function considers both edge indices and features for the
            check.

    Returns:
        True if the graph is directed, False otherwise.
    """
    undirected = is_undirected(edge_index, edge_features)
    return not undirected
[docs]
def get_num_hops(model: nn.Module) -> int:
    """
    Returns the number of hops the model is aggregating information
    from. This works only for networks based on `MessagePassing`.

    The result is the number of `MessagePassing` instances returned by
    ``model.modules()``. It is a structural count only: it does not look at
    how (or whether) the layers are invoked in ``__call__``, so layers that
    are applied in parallel or never used are counted as well.

    Args:
        model: The GNN Model.

    Returns:
        Number of hops the model is aggregating information from.

    Example:

    .. code-block:: python

        class GNN(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = GCNConv(4, 16)
                self.conv2 = GCNConv(16, 16)
                self.lin = nn.Linear(16, 2)

            def __call__(self, edge_index: mx.array, node_features: mx.array):
                x = nn.relu(self.conv1(edge_index, node_features))
                x = self.conv2(edge_index, x)
                return self.lin(x)

        get_num_hops(GNN())
        # 2
    """
    # Imported inside the function; presumably this avoids a circular import
    # between mlx_graphs.utils and mlx_graphs.nn (TODO confirm).
    from mlx_graphs.nn import MessagePassing

    # Each MessagePassing layer contributes one hop; summing booleans yields
    # an int (0 when no such layer exists).
    return sum(isinstance(module, MessagePassing) for module in model.modules())