mlx_graphs.nn.GraphNetworkBlock


class mlx_graphs.nn.GraphNetworkBlock(node_model: mlx.nn.layers.base.Module | None = None, edge_model: mlx.nn.layers.base.Module | None = None, graph_model: mlx.nn.layers.base.Module | None = None)

Bases: Module

Implements a generic Graph Network block as defined in [1].

A Graph Network block takes as input a graph with N nodes and E edges and returns a graph with the same topology, i.e. only the node, edge and graph features may change.

The input graph can have:
  • node_features: features associated with each node in the graph, provided as an array of size [N, F_N]

  • edge_features: features associated with each edge in the graph, provided as an array of size [E, F_E]

  • graph_features: features associated with the graph itself, provided as an array of size [F_U]
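The expected shapes can be illustrated with a small sketch. NumPy stands in for mlx.core here (both use the same shape conventions), and the sizes N, E, F_N, F_E and F_U are arbitrary illustrative values.

```python
import numpy as np

# A toy graph with N = 3 nodes and E = 4 edges (illustrative sizes only).
N, E = 3, 4
F_N, F_E, F_U = 2, 1, 5  # feature sizes for nodes, edges and the graph

node_features = np.zeros((N, F_N), dtype=np.float32)  # [N, F_N]
edge_features = np.ones((E, F_E), dtype=np.float32)   # [E, F_E]
graph_features = np.arange(F_U, dtype=np.float32)     # [F_U], a 1-D array

print(node_features.shape, edge_features.shape, graph_features.shape)
# (3, 2) (4, 1) (5,)
```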

The topology of the graph is specified by edge_index, an array of size [2, E] whose k-th column holds the source and destination nodes of edge k.

A Graph Network block is initialized by providing node, edge and global models (node_model, edge_model and graph_model, all optional), which are used to update the node, edge and global features, respectively (if present). Depending on which models are provided and how they are implemented, the Graph Network block acts as a flexible meta-layer that can express other architectures, such as message-passing networks, deep sets and relation networks (see [1]).
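To make the edge_index convention concrete, the sketch below gathers the features of each edge's endpoints and sum-aggregates per-edge messages at the destination nodes, the core step a node model would perform to implement message passing. NumPy stands in for mlx.core, and the graph and features are an arbitrary example, not taken from the library.

```python
import numpy as np

# edge_index has shape [2, E]: row 0 holds source nodes, row 1 destination nodes.
# Column k describes edge k: edge_index[0, k] -> edge_index[1, k].
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 0]])
node_features = np.array([[1.0, 0.0],
                          [0.0, 1.0],
                          [2.0, 2.0]])  # [N, F_N] with N = 3

src, dst = edge_index[0], edge_index[1]

# Gather the features of each edge's endpoints; both results have shape [E, F_N].
src_feats = node_features[src]
dst_feats = node_features[dst]

# Sum-aggregate per-edge messages (here simply the source features) into their
# destination nodes. np.add.at accumulates correctly when dst has repeated indices.
aggregated = np.zeros_like(node_features)
np.add.at(aggregated, dst, src_feats)
print(aggregated)
```

Node 0 receives messages from nodes 1 and 2, so its aggregated row is their sum; a plain fancy-indexed `+=` would drop one of them, which is why `np.add.at` is used.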

For a usage example, see here.

Parameters:
  • node_model (Optional[Module]) – a callable Module which updates a graph’s node features

  • edge_model (Optional[Module]) – a callable Module which updates a graph’s edge features

  • graph_model (Optional[Module]) – a callable Module which updates a graph’s global features

References

[1] Battaglia et al. Relational Inductive Biases, Deep Learning, and Graph Networks. arXiv:1806.01261, 2018.

__call__(edge_index: mlx.core.array, node_features: mlx.core.array | None = None, edge_features: mlx.core.array | None = None, graph_features: mlx.core.array | None = None) → tuple[mlx.core.array | None, mlx.core.array | None, mlx.core.array | None]

Forward pass of the Graph Network block.

Parameters:
  • edge_index (array) – array of size [2, E], where each column contains the source and destination nodes of an edge.

  • node_features (Optional[array]) – features associated with each node in the graph, provided as an array of size [N, F_N]

  • edge_features (Optional[array]) – features associated with each edge in the graph, provided as an array of size [E, F_E]

  • graph_features (Optional[array]) – features associated with the graph itself, provided as an array of size [F_U]

Return type:

tuple[Optional[array], Optional[array], Optional[array]]

Returns:

A tuple of the updated node, edge and global features, in that order.
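The control flow of a forward pass can be sketched as a plain function. This is a hedged NumPy sketch, not the library's implementation: it assumes the edge, then node, then global update order described in [1], a hypothetical positional model signature (edge_index, node_features, edge_features, graph_features), and that features without a corresponding model are returned unchanged. The helper gn_block_forward and the toy models are hypothetical names introduced for illustration.

```python
from typing import Callable, Optional

import numpy as np

Array = np.ndarray
# Hypothetical model signature, assumed for this sketch only.
Model = Callable[[Array, Optional[Array], Optional[Array], Optional[Array]], Array]


def gn_block_forward(
    edge_index: Array,
    node_features: Optional[Array] = None,
    edge_features: Optional[Array] = None,
    graph_features: Optional[Array] = None,
    edge_model: Optional[Model] = None,
    node_model: Optional[Model] = None,
    graph_model: Optional[Model] = None,
) -> tuple[Optional[Array], Optional[Array], Optional[Array]]:
    # Assumed order (from [1]): edges, then nodes, then the graph.
    # Each later model sees the features already updated by earlier ones.
    if edge_model is not None:
        edge_features = edge_model(edge_index, node_features, edge_features, graph_features)
    if node_model is not None:
        node_features = node_model(edge_index, node_features, edge_features, graph_features)
    if graph_model is not None:
        graph_features = graph_model(edge_index, node_features, edge_features, graph_features)
    # Features without a model pass through unchanged (possibly None).
    return node_features, edge_features, graph_features


def toy_edge_model(edge_index, x, e, u):
    # New edge feature: sum of the two endpoint node features, shape [E, F_N].
    return x[edge_index[0]] + x[edge_index[1]]


def toy_node_model(edge_index, x, e, u):
    # New node feature: sum of the updated features of incoming edges, shape [N, F_N].
    out = np.zeros_like(x)
    np.add.at(out, edge_index[1], e)
    return out


edge_index = np.array([[0, 1, 2], [1, 2, 0]])  # a directed 3-cycle
x = np.array([[1.0], [2.0], [3.0]])

nodes, edges, graph = gn_block_forward(
    edge_index, node_features=x, edge_model=toy_edge_model, node_model=toy_node_model
)
print(nodes.ravel(), edges.ravel(), graph)
```

With no graph model and no input graph_features, the third entry of the returned tuple stays None, while the node and edge entries reflect the two toy updates.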
