mlx_graphs.utils.topology.get_num_hops

mlx_graphs.utils.topology.get_num_hops(model: mlx.nn.Module) → int

Returns the number of hops from which the model aggregates information. This works only for networks built from MessagePassing layers.
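The count can be illustrated with a minimal, framework-free sketch. It assumes that get_num_hops behaves like PyTorch Geometric's function of the same name: it walks every submodule of the model and counts those that are MessagePassing instances. The Module, MessagePassing, Linear and count_hops names below are hypothetical stand-ins, not the mlx_graphs API.

```python
from typing import Iterator, List


class Module:
    """Hypothetical stand-in for mlx.nn.Module: a node in a tree of layers."""

    def __init__(self, *children: "Module") -> None:
        self.children: List["Module"] = list(children)

    def modules(self) -> Iterator["Module"]:
        # Depth-first traversal yielding this module and all of its descendants.
        yield self
        for child in self.children:
            yield from child.modules()


class MessagePassing(Module):
    """Hypothetical stand-in for a message-passing (graph convolution) layer."""


class Linear(Module):
    """Hypothetical stand-in for a layer that does not exchange messages."""


def count_hops(model: Module) -> int:
    # Each message-passing layer lets a node see one hop further,
    # so the hop count is the number of such layers in the module tree.
    return sum(isinstance(m, MessagePassing) for m in model.modules())


# Two graph convolutions followed by a linear head, mirroring the GNN in the example.
gnn = Module(MessagePassing(), MessagePassing(), Linear())
print(count_hops(gnn))  # 2
```

In this sketch, layers that are not MessagePassing subclasses contribute nothing to the count, which mirrors the restriction stated above. A count like this reflects the layers a model contains, not how many times its `__call__` invokes them.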

Parameters:

model (Module) – The GNN model.

Return type:

int

Returns:

The number of hops from which the model aggregates information.
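Why the number of message-passing layers equals the number of hops can be seen by tracking a node's receptive field. The sketch below is illustrative only: receptive_field is a hypothetical helper, edges are assumed to be (source, target) pairs, and each node is assumed to keep its own features between rounds (as with a self-loop). Each round of aggregation pulls in one more ring of neighbours, so after k rounds a node depends on its k-hop neighbourhood.

```python
from typing import List, Set, Tuple


def receptive_field(edges: List[Tuple[int, int]], node: int, num_layers: int) -> Set[int]:
    """Nodes whose features can reach `node` after `num_layers` aggregation rounds.

    Hypothetical helper: in each round, every reached node aggregates messages
    from its in-neighbours (sources of edges pointing at it), like one
    message-passing layer.
    """
    reached = {node}
    for _ in range(num_layers):
        # The right-hand set is built from the current `reached` before the update.
        reached |= {src for src, dst in edges if dst in reached}
    return reached


# Undirected path graph 0 - 1 - 2 - 3 - 4, stored with both edge directions.
path = [(i, i + 1) for i in range(4)] + [(i + 1, i) for i in range(4)]
print(sorted(receptive_field(path, 0, 2)))  # [0, 1, 2]
```

With two layers, node 0 sees exactly the nodes within two hops, which is the quantity a value of 2 describes.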

Example:

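import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.nn import GCNConv
from mlx_graphs.utils.topology import get_num_hops


class GNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(4, 16)
        self.conv2 = GCNConv(16, 16)
        self.lin = nn.Linear(16, 2)

    def __call__(self, edge_index: mx.array, node_features: mx.array):
        # Each GCNConv layer extends the aggregation by one hop.
        x = nn.relu(self.conv1(edge_index, node_features))
        x = self.conv2(edge_index, x)
        return self.lin(x)


# Two MessagePassing layers (conv1 and conv2), so the model spans two hops.
get_num_hops(GNN())
# 2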