mlx_graphs.nn.GATConv
- class mlx_graphs.nn.GATConv(node_features_dim: int, out_features_dim: int, heads: int = 1, concat: bool = True, bias: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, edge_features_dim: int | None = None, **kwargs)
Bases: MessagePassing
Graph Attention Network convolution layer.
- Parameters:
  - node_features_dim (int) – Size of the input node features.
  - out_features_dim (int) – Size of the output node embeddings.
  - heads (int) – Number of attention heads.
  - concat (bool) – Whether to concatenate the head outputs (True) or reduce them with a mean (False).
  - bias (bool) – Whether to use a bias in the node projection.
  - negative_slope (float) – Negative slope of the leaky ReLU.
  - dropout (float) – Dropout probability p.
  - edge_features_dim (int | None) – Size of the edge features, if edge features are used.
Example:
    import mlx.core as mx
    from mlx_graphs.nn import GATConv

    conv = GATConv(16, 32, heads=2, concat=True)
    edge_index = mx.array([[0, 1, 2, 3, 4], [0, 0, 1, 1, 3]])
    node_features = mx.random.uniform(low=0, high=1, shape=(5, 16))
    h = conv(edge_index, node_features)

    >>> h.shape
    [5, 64]
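The effect of concat on the output size can be illustrated without the library. The following is a minimal NumPy sketch, not the layer's actual implementation: the `(num_nodes, heads, out_features_dim)` layout of `per_head` is an assumption made for illustration, and the random values merely stand in for real per-head embeddings.

```python
import numpy as np

num_nodes, heads, out_features_dim = 5, 2, 32

# Stand-in for the per-head node embeddings of a multi-head layer.
# The (num_nodes, heads, out_features_dim) layout is an assumption.
rng = np.random.default_rng(0)
per_head = rng.uniform(size=(num_nodes, heads, out_features_dim))

# concat=True: head outputs are placed side by side along the feature axis.
concatenated = per_head.reshape(num_nodes, heads * out_features_dim)

# concat=False: head outputs are averaged, so the size stays out_features_dim.
averaged = per_head.mean(axis=1)

print(concatenated.shape)  # (5, 64)
print(averaged.shape)      # (5, 32)
```

With heads=2 and out_features_dim=32, concatenation gives 64 output features per node, which matches the shape reported in the example above.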
- __call__(edge_index: mlx.core.array, node_features: mlx.core.array, edge_features: mlx.core.array | None = None) → mlx.core.array
Computes the forward pass of GATConv.
Methods
message(src_features, dst_features, ...[, ...]) – Computes a message for each edge in the graph following GAT's propagation rule.
- message(src_features: mlx.core.array, dst_features: mlx.core.array, alpha_src: mlx.core.array, alpha_dst: mlx.core.array, index: mlx.core.array, edge_features: mlx.core.array | None = None) → mlx.core.array
Computes a message for each edge in the graph following GAT’s propagation rule.
- Parameters:
  - src_features (array) – Features of the source nodes.
  - dst_features (array) – Features of the destination nodes (not used in this function but included for compatibility).
  - alpha_src (array) – Precomputed attention values for the source nodes.
  - alpha_dst (array) – Precomputed attention values for the destination nodes.
  - index (array) – 1D array with indices of either src or dst nodes to compute softmax.
  - edge_features (Optional[array]) – Features of the edges in the graph.
- Returns:
The computed messages for each edge in the graph.
- Return type:
mx.array
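As a rough illustration of the propagation rule, here is a NumPy sketch of the general GAT recipe: add the source and destination attention values, apply a leaky ReLU, normalise with a softmax over the edges sharing the same `index`, and scale the source features by the resulting weights. This is not the mlx-graphs implementation. The helper names (`leaky_relu`, `segment_softmax`, `gat_message`), the `num_nodes` argument, and the `(edges, heads, features)` layout are all assumptions made for the example, and edge features and dropout are left out.

```python
import numpy as np


def leaky_relu(x, negative_slope=0.2):
    # Identity for positive inputs, scaled by negative_slope otherwise.
    return np.where(x > 0, x, negative_slope * x)


def segment_softmax(scores, index, num_segments):
    """Softmax over the rows of `scores` that share the same value in `index`."""
    # Per-segment maximum, subtracted for numerical stability.
    seg_max = np.full((num_segments,) + scores.shape[1:], -np.inf)
    np.maximum.at(seg_max, index, scores)
    exp = np.exp(scores - seg_max[index])
    # Per-segment sum of exponentials, used as the normaliser.
    seg_sum = np.zeros_like(seg_max)
    np.add.at(seg_sum, index, exp)
    return exp / seg_sum[index]


def gat_message(src_features, alpha_src, alpha_dst, index, num_nodes,
                negative_slope=0.2):
    # Hypothetical helper. Assumed shapes:
    #   src_features: (E, H, F); alpha_src, alpha_dst: (E, H); index: (E,)
    scores = leaky_relu(alpha_src + alpha_dst, negative_slope)
    weights = segment_softmax(scores, index, num_nodes)
    # Each edge's message is its source features scaled by the attention weight.
    return weights[..., None] * src_features


rng = np.random.default_rng(0)
dst = np.array([0, 0, 1, 1, 3])  # destination node of each of the 5 edges
num_edges, heads, feat = 5, 2, 4
src_features = rng.normal(size=(num_edges, heads, feat))
alpha_src = rng.normal(size=(num_edges, heads))
alpha_dst = rng.normal(size=(num_edges, heads))

messages = gat_message(src_features, alpha_src, alpha_dst, dst, num_nodes=5)
print(messages.shape)  # (5, 2, 4)
```

In this sketch the softmax runs over destination indices, so the weights of all edges arriving at the same node sum to one per head. The last edge is the only one arriving at node 3, so its weight is 1 and its message equals its source features.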