mlx_graphs.nn.GATConv

class mlx_graphs.nn.GATConv(node_features_dim: int, out_features_dim: int, heads: int = 1, concat: bool = True, bias: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_features_dim: int | None = None, **kwargs)

Bases: MessagePassing

Graph Attention Network convolution layer.

Parameters:
  • node_features_dim (int) – Size of the input node features.

  • out_features_dim (int) – Size of the output node embeddings, per attention head.

  • heads (int) – Number of attention heads. Defaults to 1.

  • concat (bool) – Whether to concatenate the outputs of the attention heads (True) or average them (False). Defaults to True.

  • bias (bool) – Whether to use a bias in the node projection. Defaults to True.

  • negative_slope (float) – Negative slope of the leaky ReLU. Defaults to 0.2.

  • dropout (float) – Dropout probability p. Defaults to 0.0.

  • edge_features_dim (Optional[int]) – Size of the edge features. Defaults to None.

Example:

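import mlx.core as mx
from mlx_graphs.nn import GATConv

# 2 heads of size 32, concatenated: 64 output features per node
conv = GATConv(16, 32, heads=2, concat=True)
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))

h = conv(edge_index, node_features)

>>> h.shape
[5, 64]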
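
The last dimension of h in the example follows from heads and concat. As a minimal sketch (the helper gat_output_dim is hypothetical and not part of mlx_graphs), the expected embedding size can be worked out like this:

```python
def gat_output_dim(out_features_dim: int, heads: int = 1, concat: bool = True) -> int:
    """Expected size of the last output dimension (hypothetical helper)."""
    if heads < 1:
        raise ValueError("heads must be >= 1")
    # Concatenation places the per-head outputs side by side;
    # mean reduction averages them, keeping the per-head size.
    return heads * out_features_dim if concat else out_features_dim


print(gat_output_dim(32, heads=2, concat=True))   # 64
print(gat_output_dim(32, heads=2, concat=False))  # 32
```

With concat=False the same layer would therefore be expected to return 32 features per node instead of 64.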

__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_features: mlx.core.array | None = None) → mlx.core.array

Computes the forward pass of GATConv.

Parameters:
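  • edge_index (array) – Input edge index of shape [2, num_edges].

  • node_features (array) – Input node features of shape [num_nodes, node_features_dim].

  • edge_features (Optional[array]) – Edge features of shape [num_edges, edge_features_dim], if the layer was created with edge_features_dim.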

Returns:

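The computed node embeddings, of shape [num_nodes, heads * out_features_dim] if concat is True, otherwise [num_nodes, out_features_dim].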

Return type:

mx.array
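
The [2, num_edges] layout expected for edge_index can be built from a plain list of edge pairs. The sketch below uses only the standard library; treating row 0 as source nodes and row 1 as destination nodes is an assumption borrowed from common graph libraries and is not stated on this page.

```python
# Edges as (source, destination) pairs; this row order is an assumption,
# not something this page specifies.
edges = [(0, 0), (1, 0), (2, 1), (3, 1), (4, 3)]

# Transpose the pair list into two rows: [sources, destinations].
edge_index = [list(row) for row in zip(*edges)]

# Sanity-check the [2, num_edges] layout.
num_edges = len(edges)
assert len(edge_index) == 2
assert all(len(row) == num_edges for row in edge_index)

print(edge_index)  # [[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]
```

The resulting nested list matches the one passed to mx.array(...) in the class example.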

Methods

message(src_features, dst_features, ...[, ...])

Computes a message for each edge in the graph following GAT's propagation rule.

message(src_features: mlx.core.array, dst_features: mlx.core.array, alpha_src: mlx.core.array, alpha_dst: mlx.core.array, index: mlx.core.array, edge_features: mlx.core.array | None = None) → mlx.core.array

Computes a message for each edge in the graph following GAT’s propagation rule.

Parameters:
  • src_features (array) – Features of the source nodes.

  • dst_features (array) – Features of the destination nodes (not used in this function but included for compatibility).

  • alpha_src (array) – Precomputed attention values for the source nodes.

  • alpha_dst (array) – Precomputed attention values for the destination nodes.

  • index (array) – 1D array of src or dst node indices that defines the groups over which the softmax is computed.

  • edge_features (Optional[array]) – Features of the edges in the graph.

Returns:

The computed messages for each edge in the graph.

Return type:

mx.array
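
To make the propagation rule concrete, here is a pure-Python sketch of the standard GAT message computation for a single attention head. It is an illustration under stated assumptions, not the library's implementation: gat_messages and leaky_relu are hypothetical names, alpha_src and alpha_dst are assumed to be already gathered per edge, and dropout and edge features are left out.

```python
import math
from collections import defaultdict


def leaky_relu(x: float, negative_slope: float = 0.2) -> float:
    return x if x >= 0 else negative_slope * x


def gat_messages(src_features, alpha_src, alpha_dst, index, negative_slope=0.2):
    """Per-edge messages for one attention head (illustrative only).

    src_features: per-edge source feature vectors (list of lists)
    alpha_src, alpha_dst: per-edge attention scalars (assumed pre-gathered)
    index: per-edge node id that defines the softmax groups
    """
    # Unnormalised attention score for every edge.
    scores = [leaky_relu(a + b, negative_slope) for a, b in zip(alpha_src, alpha_dst)]

    # Numerically stable softmax over the edges that share an index entry.
    max_per_group = defaultdict(lambda: -math.inf)
    for s, i in zip(scores, index):
        max_per_group[i] = max(max_per_group[i], s)
    exps = [math.exp(s - max_per_group[i]) for s, i in zip(scores, index)]
    sum_per_group = defaultdict(float)
    for e, i in zip(exps, index):
        sum_per_group[i] += e
    alpha = [e / sum_per_group[i] for e, i in zip(exps, index)]

    # Each message is the source feature vector scaled by its attention weight.
    return [[a * x for x in feats] for a, feats in zip(alpha, src_features)]


msgs = gat_messages(
    src_features=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    alpha_src=[0.0, 0.0, 1.0],
    alpha_dst=[0.0, 0.0, 0.0],
    index=[0, 0, 1],
)
print(msgs)  # [[0.5, 1.0], [1.5, 2.0], [5.0, 6.0]]
```

The first two edges share index 0 and have equal scores, so each receives weight 0.5; the third edge is alone in its group and keeps weight 1.0.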