mlx_graphs.nn.GATv2Conv

class mlx_graphs.nn.GATv2Conv(node_features_dim: int, out_features_dim: int, heads: int = 1, concat: bool = True, bias: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_features_dim: int | None = None, **kwargs)

Bases: MessagePassing

Graph Attention Network convolution layer with dynamic attention.

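A modification of GATConv based on the paper “How Attentive are Graph Attention Networks?”. GATConv applies the LeakyReLU after the dot product with the attention vector, so the ranking of candidate neighbors is the same for every query node (static attention). GATv2Conv applies the LeakyReLU before that dot product, which lets the ranking depend on the node computing the attention (dynamic attention).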
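In formulas, using the split-weight form of the paper's scoring functions: $W$, $W_s$, $W_t$ (projections), $\mathbf{a}$ (attention vector) and $\mathcal{N}(i)$ (incoming neighbors of node $i$) are notation chosen for this sketch, and the exact parameterization inside mlx_graphs (bias, heads, edge-feature term, dropout) may differ.

```latex
\begin{aligned}
e^{\text{GAT}}_{ij}   &= \mathrm{LeakyReLU}\!\left(\mathbf{a}^\top \left[\, W h_i \,\Vert\, W h_j \,\right]\right) \\
e^{\text{GATv2}}_{ij} &= \mathbf{a}^\top \, \mathrm{LeakyReLU}\!\left(W_t h_i + W_s h_j\right) \\
\alpha_{ij}           &= \frac{\exp\left(e_{ij}\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(e_{ik}\right)} \\
h_i'                  &= \sum_{j \in \mathcal{N}(i)} \alpha_{ij} \, W_s h_j
\end{aligned}
```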

Parameters:
  • node_features_dim (int) – Size of input node features

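  • out_features_dim (int) – Size of the output node embeddings per attention head

  • heads (int) – Number of attention heads

  • concat (bool) – If True, concatenate the outputs of all heads (heads * out_features_dim features per node); if False, average them (out_features_dim features per node)

  • bias (bool) – Whether to use bias in the node projection

  • negative_slope (float) – Negative slope of the LeakyReLU used in the attention computation

  • dropout (float) – Dropout probability p

  • edge_features_dim (Optional[int]) – Size of the edge features; leave as None when no edge features are used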

Example:

import mlx.core as mx
from mlx_graphs.nn import GATv2Conv

conv = GATv2Conv(16, 32, heads=2, concat=True)
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))

h = conv(edge_index, node_features)

>>> h.shape
[5, 64]
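The output width above follows from heads=2 with concat=True: two heads of 32 features each are concatenated into 64. The plain-Python sketch below illustrates the two reduction modes on toy values. `combine_heads` is a hypothetical helper written for illustration, not part of mlx_graphs, and the per-head layout it assumes (`per_head[k][i]` is node i's embedding from head k) is also an assumption.

```python
from typing import List

Vector = List[float]


def combine_heads(per_head: List[List[Vector]], concat: bool) -> List[Vector]:
    """Combine per-head node embeddings into one embedding per node.

    Hypothetical helper: per_head[k][i] is node i's embedding from head k.
    concat=True  -> heads * out_features_dim values per node (heads side by side).
    concat=False -> out_features_dim values per node (element-wise mean over heads).
    """
    heads = len(per_head)
    num_nodes = len(per_head[0])
    combined: List[Vector] = []
    for i in range(num_nodes):
        if concat:
            # Place each head's vector one after another.
            row = [value for k in range(heads) for value in per_head[k][i]]
        else:
            # Average each feature position across heads.
            dim = len(per_head[0][i])
            row = [sum(per_head[k][i][d] for k in range(heads)) / heads for d in range(dim)]
        combined.append(row)
    return combined


# Two heads, two nodes, out_features_dim = 3 (toy values).
per_head = [
    [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]],  # head 0
    [[3.0, 4.0, 5.0], [2.0, 2.0, 2.0]],  # head 1
]
print(combine_heads(per_head, concat=True)[0])   # [1.0, 2.0, 3.0, 3.0, 4.0, 5.0]
print(combine_heads(per_head, concat=False)[0])  # [2.0, 3.0, 4.0]
```

With concat=False the per-node width stays at out_features_dim, so the same example would produce 32 features per node instead of 64.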
__call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_features: mlx.core.array | None = None) → mlx.core.array

Computes the forward pass of GATv2Conv.

Parameters:
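  • edge_index (array) – input edge index of shape [2, num_edges]

  • node_features (array) – input node features of shape [num_nodes, node_features_dim]

  • edge_features (Optional[array]) – edge features of shape [num_edges, edge_features_dim], used when the layer is constructed with edge_features_dim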

Returns:

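computed node embeddings of shape [num_nodes, heads * out_features_dim] if concat is True, otherwise [num_nodes, out_features_dim]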

Return type:

mx.array
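A minimal single-head sketch of what the forward pass computes, assuming the split-weight GATv2 formulation from the paper and assuming row 0 of edge_index holds source indices and row 1 destination indices. `w_s`, `w_t` and `att` are hypothetical names for the source projection, target projection and attention vector; bias, dropout, multiple heads and edge features are omitted. It uses plain Python lists so the arithmetic is easy to follow; the actual layer operates on mx.array values and may organize the computation differently.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape [rows, cols]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply an [out, in] matrix by an [in] vector."""
    return [sum(w_rc * x_c for w_rc, x_c in zip(row, x)) for row in w]


def leaky_relu(v: Vector, negative_slope: float) -> Vector:
    return [x if x > 0 else negative_slope * x for x in v]


def gatv2_forward(
    edge_index: List[List[int]],  # [2, num_edges]: row 0 = sources, row 1 = destinations (assumed)
    x: List[Vector],              # [num_nodes, in_dim]
    w_s: Matrix,                  # [out_dim, in_dim] source projection (hypothetical name)
    w_t: Matrix,                  # [out_dim, in_dim] target projection (hypothetical name)
    att: Vector,                  # [out_dim] attention vector (hypothetical name)
    negative_slope: float = 0.2,
) -> List[Vector]:
    """Single-head GATv2 sketch: score edges, softmax per destination, aggregate."""
    src, dst = edge_index
    num_nodes, out_dim = len(x), len(w_s)
    xs = [matvec(w_s, v) for v in x]  # projected source features
    xt = [matvec(w_t, v) for v in x]  # projected target features

    # Dynamic attention: LeakyReLU is applied before the dot product with att.
    scores = []
    for j, i in zip(src, dst):
        hidden = leaky_relu([a + b for a, b in zip(xt[i], xs[j])], negative_slope)
        scores.append(sum(a * h for a, h in zip(att, hidden)))

    # Softmax over the incoming edges of each destination (max-shifted for stability).
    max_score = [-math.inf] * num_nodes
    for e, i in enumerate(dst):
        max_score[i] = max(max_score[i], scores[e])
    exp_scores = [math.exp(s - max_score[i]) for s, i in zip(scores, dst)]
    denom = [0.0] * num_nodes
    for e, i in enumerate(dst):
        denom[i] += exp_scores[e]

    # Aggregate: each destination sums alpha_ij * (W_s x_j) over its incoming edges.
    out = [[0.0] * out_dim for _ in range(num_nodes)]
    for e, (j, i) in enumerate(zip(src, dst)):
        alpha = exp_scores[e] / denom[i]
        for d in range(out_dim):
            out[i][d] += alpha * xs[j][d]
    return out


# Same edge_index as the class example, with 2-dimensional toy features.
edge_index = [[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
h = gatv2_forward(edge_index, x, identity, identity, att=[1.0, 1.0])
print(h[0])  # [0.5, 0.5]  (two incoming edges with equal scores)
print(h[2])  # [0.0, 0.0]  (node 2 has no incoming edges)
```

In this sketch a node without incoming edges keeps a zero embedding; whether the real layer adds self-loops or a bias on top is not shown here.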

Methods

message(src_features, dst_features, src, ...)

Computes messages between connected nodes.

message(src_features: mlx.core.array, dst_features: mlx.core.array, src: mlx.core.array, dst: mlx.core.array, index: mlx.core.array, edge_features: mlx.core.array | None = None) → mlx.core.array

Computes messages between connected nodes.

Each message is the source node embedding of an edge scaled by that edge's attention coefficient; the coefficients are normalized across the edges that share the same index.

Parameters:
  • src_features (array) – Source node embeddings

  • dst_features (array) – Destination node embeddings

  • src (array) – Source node index of each edge

  • dst (array) – Destination node index of each edge

  • index (array) – Indices used to group edges when normalizing the attention coefficients (typically the destination node of each edge)

  • edge_features (Optional[array]) – Edge features. Default: None

Return type:

array
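The normalization step that the index argument controls can be sketched as a grouped softmax: attention scores are normalized separately within each group of edges sharing an index value, so the coefficients of each node's incoming edges sum to 1. This is a plain-Python sketch under the assumption that index holds destination indices; `grouped_softmax` is a hypothetical helper, not an mlx_graphs function.

```python
import math
from typing import Dict, List


def grouped_softmax(scores: List[float], index: List[int]) -> List[float]:
    """Softmax of per-edge scores, normalized separately within each group.

    Edges sharing the same index value (assumed: the same destination node)
    form one group, so the returned coefficients sum to 1 within each group.
    """
    group_max: Dict[int, float] = {}
    for s, g in zip(scores, index):
        group_max[g] = max(group_max.get(g, -math.inf), s)
    # Subtract each group's max before exponentiating to avoid overflow.
    exps = [math.exp(s - group_max[g]) for s, g in zip(scores, index)]
    group_sum: Dict[int, float] = {}
    for e, g in zip(exps, index):
        group_sum[g] = group_sum.get(g, 0.0) + e
    return [e / group_sum[g] for e, g in zip(exps, index)]


# Destinations from the class example: edges 0-1 point to node 0, 2-3 to node 1, 4 to node 3.
index = [0, 0, 1, 1, 3]
scores = [2.0, 2.0, 1000.0, 1000.0, -5.0]
print(grouped_softmax(scores, index))  # [0.5, 0.5, 0.5, 0.5, 1.0]
```

Under the same assumption, the message for edge e would then be alpha[e] times that edge's source embedding. The max shift matters in practice: math.exp(1000.0) alone would overflow.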