mlx_graphs.utils.topology.get_num_hops
- mlx_graphs.utils.topology.get_num_hops(model: mlx.nn.Module) → int
Returns the number of hops the model aggregates information from. This only works for models built from MessagePassing layers.
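To make "number of hops" concrete: each message-passing layer lets a node gather features from neighbors one edge further away, so after k layers a node's output can depend on its whole k-hop neighborhood. The pure-Python sketch below does not use MLX. The function `receptive_field` and the small path graph are hypothetical illustrations, and the sketch assumes each layer also keeps a node's own features (as GCN-style layers with self-loops do).

```python
def receptive_field(edges, node, num_layers):
    """Return the set of nodes whose input features can reach `node`
    after `num_layers` rounds of neighbor aggregation.
    Hypothetical helper; edges are treated as undirected."""
    neighbors = {}
    for u, v in edges:
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)
    reached = {node}  # the node's own features (self-loop assumption)
    for _ in range(num_layers):
        # One layer pulls in the direct neighbors of everything reached so far.
        reached = reached | {n for r in reached for n in neighbors.get(r, ())}
    return reached

# Path graph 0 - 1 - 2 - 3 - 4
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
print(sorted(receptive_field(edges, 0, 1)))  # [0, 1]
print(sorted(receptive_field(edges, 0, 2)))  # [0, 1, 2]
```

This is why a model with two message-passing layers reports 2 hops: node 0 only sees nodes up to two edges away.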
- Parameters:
model (Module) – The GNN model.
- Return type:
int
- Returns:
The number of hops the model aggregates information from.
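Example:
```python
import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.nn import GCNConv
from mlx_graphs.utils.topology import get_num_hops

class GNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Two message-passing layers, so the model aggregates over 2 hops.
        self.conv1 = GCNConv(4, 16)
        self.conv2 = GCNConv(16, 16)
        self.lin = nn.Linear(16, 2)

    def __call__(self, edge_index: mx.array, node_features: mx.array):
        x = nn.relu(self.conv1(edge_index, node_features))
        # Feed the first layer's output (not the raw features) into conv2.
        x = self.conv2(edge_index, x)
        return self.lin(x)

get_num_hops(GNN())  # 2
```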
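The restriction to MessagePassing-based networks is easiest to see if the function counts MessagePassing submodules in the model. That is an assumption: it is how PyTorch Geometric's function of the same name works, but this page does not confirm that mlx-graphs does the same. The sketch below uses hypothetical stand-in classes (`Module`, `MessagePassing`, `GCNLayer`, `Linear`) and a hypothetical `count_hops` instead of the real MLX and mlx-graphs types.

```python
class Module:
    """Hypothetical stand-in for mlx.nn.Module: submodules are attributes."""
    def modules(self):
        # Depth-first walk over this module and every nested submodule.
        yield self
        for value in vars(self).values():
            if isinstance(value, Module):
                yield from value.modules()

class MessagePassing(Module):
    """Hypothetical stand-in for the MessagePassing base class."""

class GCNLayer(MessagePassing):
    pass

class Linear(Module):
    pass

def count_hops(model):
    # Assumed logic: each MessagePassing layer contributes one hop.
    return sum(isinstance(m, MessagePassing) for m in model.modules())

class GNN(Module):
    def __init__(self):
        self.conv1 = GCNLayer()
        self.conv2 = GCNLayer()
        self.lin = Linear()

print(count_hops(GNN()))  # 2
```

Under this counting approach, a custom layer that aggregates over neighbors without subclassing MessagePassing is not counted, and a single layer applied twice in `__call__` still counts as one hop.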