mlx_graphs.nn.GraphNetworkBlock
- class mlx_graphs.nn.GraphNetworkBlock(node_model: mlx.nn.layers.base.Module | None = None, edge_model: mlx.nn.layers.base.Module | None = None, graph_model: mlx.nn.layers.base.Module | None = None)
Bases: Module
Implements a generic Graph Network block as defined in [1].
A Graph Network block takes as input a graph with N nodes and E edges and returns a graph with the same topology.
- The input graph can have:
  - node_features: features associated with each node in the graph, provided as an array of size [N, F_N]
  - edge_features: features associated with each edge in the graph, provided as an array of size [E, F_E]
  - graph_features: features associated with the graph itself, provided as an array of size [F_U]
The topology of the graph is specified by edge_index, an array of size [2, E] in which each column holds the source and destination node of one edge. A Graph Network block is initialized with node, edge and global models (all optional), which are used to update the node, edge and global features, respectively (if present). Depending on which models are provided and how they are implemented, the Graph Network block acts as a flexible meta-layer that can be used to implement other architectures, such as message-passing networks, deep sets, relation networks and more (see [1]). For a usage example, see here.
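To make the meta-layer idea concrete, the block below sketches one Graph Network update in plain NumPy, following the edge, then node, then global update order of the GN block in [1]. It is not the mlx_graphs implementation: the function name gn_block, the sum aggregation, the convention that each model receives the concatenation of its inputs, and the pass-through behaviour when a model is omitted are all assumptions of this sketch.

```python
import numpy as np


def _concat(parts):
    """Concatenate the non-None feature blocks along the last (feature) axis."""
    return np.concatenate([p for p in parts if p is not None], axis=-1)


def gn_block(edge_index, node_features=None, edge_features=None, graph_features=None,
             edge_model=None, node_model=None, graph_model=None):
    """Hypothetical sketch of one Graph Network block update, after [1].

    Assumptions (not necessarily mlx_graphs' behaviour): each model is a callable
    applied to the concatenation of its inputs, aggregation is a sum, and a
    missing model leaves the corresponding features unchanged.
    """
    src, dst = edge_index[0], edge_index[1]
    num_edges = edge_index.shape[1]

    # 1. Edge update: phi_e(e_k, v_src(k), v_dst(k), u) for every edge k.
    if edge_model is not None:
        v_src = None if node_features is None else node_features[src]
        v_dst = None if node_features is None else node_features[dst]
        u_e = None if graph_features is None else np.tile(graph_features, (num_edges, 1))
        edge_features = edge_model(_concat([edge_features, v_src, v_dst, u_e]))

    # 2. Node update: sum the (possibly updated) features of each node's incoming
    #    edges, then apply phi_v(aggregated edges, v_i, u).
    if node_model is not None and node_features is not None:
        num_nodes = node_features.shape[0]
        agg = None
        if edge_features is not None:
            agg = np.zeros((num_nodes, edge_features.shape[1]))
            np.add.at(agg, dst, edge_features)  # scatter-add edges onto destination nodes
        u_v = None if graph_features is None else np.tile(graph_features, (num_nodes, 1))
        node_features = node_model(_concat([agg, node_features, u_v]))

    # 3. Global update: phi_u(sum of edges, sum of nodes, u).
    if graph_model is not None:
        e_sum = None if edge_features is None else edge_features.sum(axis=0)
        v_sum = None if node_features is None else node_features.sum(axis=0)
        graph_features = graph_model(_concat([e_sum, v_sum, graph_features]))

    return node_features, edge_features, graph_features


# Toy graph: 4 nodes, 3 edges (0->1, 1->2, 2->3), F_N=5, F_E=2, F_U=3.
rng = np.random.default_rng(0)
edge_index = np.array([[0, 1, 2], [1, 2, 3]])
x, e, u = rng.random((4, 5)), rng.random((3, 2)), rng.random(3)

# Hypothetical "models": fixed random linear maps sized to the concatenated inputs.
W_e = rng.random((2 + 5 + 5 + 3, 8))  # edge input: e, v_src, v_dst, u
W_v = rng.random((8 + 5 + 3, 6))      # node input: summed edges, v, u
W_u = rng.random((8 + 6 + 3, 4))      # global input: edge sum, node sum, u
new_x, new_e, new_u = gn_block(
    edge_index, x, e, u,
    edge_model=lambda z: z @ W_e,
    node_model=lambda z: z @ W_v,
    graph_model=lambda z: z @ W_u,
)
print(new_x.shape, new_e.shape, new_u.shape)  # (4, 6) (3, 8) (4,)
```

Under these assumptions, leaving edge_features as None and passing only node_model and graph_model reduces the update to a deep-sets-like computation (phi_v(v_i, u) per node, then phi_u of the summed nodes and u), which illustrates how the choice of models selects the architecture.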
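- Parameters:
node_model (Optional[Module]) – model used to update the node features.
edge_model (Optional[Module]) – model used to update the edge features.
graph_model (Optional[Module]) – model used to update the global (graph) features.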
References
[1] Battaglia et al., "Relational Inductive Biases, Deep Learning, and Graph Networks", arXiv:1806.01261, 2018.
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array | None = None, edge_features: mlx.core.array | None = None, graph_features: mlx.core.array | None = None) -> tuple[mlx.core.array | None, mlx.core.array | None, mlx.core.array | None]
Forward pass of the Graph Network block.
- Parameters:
edge_index (array) – array of size [2, E], where each column contains the source and destination nodes of an edge.
node_features (Optional[array]) – features associated with each node in the graph, provided as an array of size [N, F_N].
edge_features (Optional[array]) – features associated with each edge in the graph, provided as an array of size [E, F_E].
graph_features (Optional[array]) – features associated with the graph itself, provided as an array of size [F_U].
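The shape contract of the arguments above can be checked before calling the block. The helper below is hypothetical (not part of mlx_graphs) and uses NumPy arrays as a stand-in for mlx.core arrays; the check that every index in edge_index is smaller than N is an extra consistency check implied by, but not stated in, the parameter descriptions.

```python
import numpy as np


def check_gn_inputs(edge_index, node_features=None, edge_features=None, graph_features=None):
    """Hypothetical helper: validate inputs against the documented shapes.

    NumPy arrays stand in for mlx.core arrays. Returns the number of edges E.
    """
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError(f"edge_index must have shape [2, E], got {edge_index.shape}")
    num_edges = edge_index.shape[1]

    if node_features is not None:
        if node_features.ndim != 2:
            raise ValueError("node_features must have shape [N, F_N]")
        # Extra consistency check: every node index must be < N.
        if num_edges > 0 and edge_index.max() >= node_features.shape[0]:
            raise ValueError("edge_index contains a node index >= N")

    if edge_features is not None and (edge_features.ndim != 2 or edge_features.shape[0] != num_edges):
        raise ValueError("edge_features must have shape [E, F_E]")

    if graph_features is not None and graph_features.ndim != 1:
        raise ValueError("graph_features must have shape [F_U]")

    return num_edges


edge_index = np.array([[0, 1, 2], [1, 2, 3]])  # 3 edges over 4 nodes
print(check_gn_inputs(edge_index, np.zeros((4, 5)), np.zeros((3, 2)), np.zeros(3)))  # 3
```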
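- Return type:
tuple[Optional[array], Optional[array], Optional[array]]
- Returns:
A tuple (node_features, edge_features, graph_features) containing the updated node, edge and global features.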
Methods