mlx_graphs.nn.GATv2Conv#
- class mlx_graphs.nn.GATv2Conv(node_features_dim: int, out_features_dim: int, heads: int = 1, concat: bool = True, bias: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_features_dim: int | None = None, **kwargs)[source]#
Bases:
MessagePassing
Graph Attention Network convolution layer with dynamic attention.
Modification of GATConv based on the “How Attentive are Graph Attention Networks?” paper.
- Parameters:
  - node_features_dim (int) – Size of input node features
  - out_features_dim (int) – Size of output node embeddings
  - heads (int) – Number of attention heads
  - concat (bool) – Whether to concatenate the heads or take their mean
  - bias (bool) – Whether to use bias in the node projection
  - negative_slope (float) – Slope for the leaky ReLU
  - dropout (float) – Probability p for dropout
  - edge_features_dim (int, optional) – Size of edge features
Example:
```python
conv = GATv2Conv(16, 32, heads=2, concat=True)
edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))
h = conv(edge_index, node_features)
>>> h.shape
[5, 64]
```
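The “dynamic attention” of GATv2 comes from applying the attention vector after the non-linearity rather than before it. A minimal NumPy sketch of the two scoring functions for a single edge and a single head (the weights `W_src`, `W_dst`, and `a` here are made-up illustration values, not the layer's actual parameters):

```python
import numpy as np

def leaky_relu(x, negative_slope=0.2):
    # Matches the layer's default negative_slope of 0.2
    return np.where(x > 0, x, negative_slope * x)

rng = np.random.default_rng(0)
in_dim, out_dim = 16, 32

# Hypothetical per-head projection weights and attention vector
W_src = rng.normal(size=(in_dim, out_dim))
W_dst = rng.normal(size=(in_dim, out_dim))
a = rng.normal(size=(out_dim,))

h_j = rng.normal(size=(in_dim,))  # source node features
h_i = rng.normal(size=(in_dim,))  # destination node features

# GAT-style (static) attention: `a` is applied before the non-linearity,
# so the ranking of neighbours is the same for every destination node.
e_static = leaky_relu(a @ (h_j @ W_src + h_i @ W_dst))

# GATv2-style (dynamic) attention: the non-linearity comes first,
# so `a` can select different neighbours for different destinations.
e_dynamic = a @ leaky_relu(h_j @ W_src + h_i @ W_dst)
```

Both expressions yield one scalar logit per edge; the layer computes these for every edge and attention head before normalising them.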
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_features: mlx.core.array | None = None) mlx.core.array [source]#
Computes the forward pass of GATv2Conv.
Methods
message(src_features, dst_features, src, ...) – Computes messages between connected nodes.
- message(src_features: mlx.core.array, dst_features: mlx.core.array, src: mlx.core.array, dst: mlx.core.array, index: mlx.core.array, edge_features: mlx.core.array | None = None) mlx.core.array [source]#
Computes messages between connected nodes.
By default, returns the features of source nodes. Optional edge_weights can be directly integrated in kwargs.
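The per-edge logits produced for each edge are normalised by a softmax over all edges that share a destination node, which is the role of the `index` argument above. A toy NumPy sketch of that grouped normalisation (the scores are invented for illustration; this is not the library's actual implementation):

```python
import numpy as np

# Toy graph: src[k] -> dst[k] for each edge k
src = np.array([0, 1, 2, 3, 4])
dst = np.array([0, 0, 1, 1, 3])
scores = np.array([1.0, 2.0, 0.5, 0.5, 3.0])  # per-edge attention logits

# Softmax over edges grouped by destination node
norm = np.zeros_like(scores)
for n in np.unique(dst):
    mask = dst == n
    e = np.exp(scores[mask] - scores[mask].max())  # shift for stability
    norm[mask] = e / e.sum()
```

After normalisation, the attention weights for each destination node sum to one, and the aggregation step takes the weighted sum of the corresponding source-node messages.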
- Return type: mlx.core.array