Neural Networks

Linear(input_dims, output_dims[, bias])

Linear layer with Xavier Glorot weight initialization.
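As a rough illustration of what Xavier Glorot initialization means, the sketch below draws weights uniformly from [-a, a] with a = sqrt(6 / (fan_in + fan_out)). It uses plain Python and a hypothetical helper name, `glorot_uniform`; it is not the library's implementation, and whether `Linear` uses the uniform or the normal variant is an assumption here.

```python
import math
import random


def glorot_uniform(input_dims, output_dims, seed=None):
    """Hypothetical helper: build an output_dims x input_dims weight matrix.

    Each entry is sampled uniformly from [-limit, limit], where
    limit = sqrt(6 / (input_dims + output_dims)).
    """
    rng = random.Random(seed)
    limit = math.sqrt(6.0 / (input_dims + output_dims))
    return [
        [rng.uniform(-limit, limit) for _ in range(input_dims)]
        for _ in range(output_dims)
    ]


# Example: weights for a layer mapping 4 input features to 3 outputs.
weights = glorot_uniform(input_dims=4, output_dims=3, seed=0)
```

The bound shrinks as the layer gets wider, which keeps the variance of activations roughly constant from layer to layer.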

MessagePassing([aggr])

Base class for creating Message Passing Neural Networks (MPNNs) [1].

GraphNetworkBlock([node_model, edge_model, ...])

Implements a generic Graph Network block as defined in [1].

SimpleConv([aggr, combine_root_func])

A simple non-trainable message passing layer.
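To make the idea of a non-trainable message passing layer concrete, here is a minimal sketch in plain Python. Each node collects the features of its incoming neighbors and reduces them with the chosen aggregation. The function name `simple_message_passing`, the edge format (a list of `(source, target)` pairs), and the zero output for isolated nodes are assumptions for illustration; combining the result with the root node's own features (what `combine_root_func` suggests) is left out.

```python
def simple_message_passing(node_features, edges, aggr="add"):
    """Aggregate neighbor features for every node, with no learned weights.

    node_features: list of per-node feature lists.
    edges: list of (source, target) pairs; messages flow source -> target.
    aggr: "add", "mean" or "max".
    """
    num_nodes = len(node_features)
    dim = len(node_features[0])

    # Collect the messages arriving at each target node.
    inbox = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        inbox[dst].append(node_features[src])

    out = []
    for messages in inbox:
        if not messages:
            out.append([0.0] * dim)  # assumption: isolated nodes get zeros
            continue
        columns = list(zip(*messages))  # one tuple per feature dimension
        if aggr == "add":
            out.append([float(sum(c)) for c in columns])
        elif aggr == "mean":
            out.append([sum(c) / len(c) for c in columns])
        elif aggr == "max":
            out.append([float(max(c)) for c in columns])
        else:
            raise ValueError(f"unknown aggregation: {aggr}")
    return out


features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edges = [(0, 2), (1, 2), (2, 0)]
aggregated = simple_message_passing(features, edges, aggr="add")
```

In this toy graph, node 2 receives messages from nodes 0 and 1, node 0 receives one from node 2, and node 1 receives none.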

GATConv(node_features_dim, out_features_dim)

Graph Attention Network convolution layer.

GCNConv(node_features_dim, out_features_dim)

Applies a GCN convolution over input node features.

GINConv(mlp[, eps, learn_eps, ...])

Graph Isomorphism Network convolution layer from the "How Powerful are Graph Neural Networks?" paper.
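The GIN update from that paper computes, for each node i, mlp((1 + eps) * h_i + sum of h_j over the neighbors j of i). The sketch below spells this out in plain Python. The function name `gin_update`, the `(source, target)` edge format, and the stand-in `toy_mlp` are hypothetical; this shows the formula and does not mirror the `GINConv` signature.

```python
def gin_update(node_features, edges, mlp, eps=0.0):
    """One GIN step: mlp((1 + eps) * h_i + sum of incoming neighbor features)."""
    num_nodes = len(node_features)
    dim = len(node_features[0])

    # Sum the features of all incoming neighbors for each node.
    neighbor_sum = [[0.0] * dim for _ in range(num_nodes)]
    for src, dst in edges:
        for k in range(dim):
            neighbor_sum[dst][k] += node_features[src][k]

    # Weight the node's own features by (1 + eps) and add the neighbor sum.
    combined = [
        [(1.0 + eps) * node_features[i][k] + neighbor_sum[i][k] for k in range(dim)]
        for i in range(num_nodes)
    ]
    return [mlp(row) for row in combined]


def toy_mlp(row):
    # Placeholder for a trainable MLP: a plain ReLU, for illustration only.
    return [max(0.0, v) for v in row]


features = [[1.0, -2.0], [3.0, 4.0]]
edges = [(0, 1), (1, 0)]
updated = gin_update(features, edges, toy_mlp, eps=0.5)
```

When `eps` is learned rather than fixed (as the `learn_eps` parameter suggests), it would be a trainable scalar instead of the constant used here.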

SAGEConv(node_features_dim, out_features_dim)

GraphSAGE convolution layer from the "Inductive Representation Learning on Large Graphs" paper.

GeneralizedRelationalConv(in_features_dim, ...)

Generalized relational convolution layer from the "Neural Bellman-Ford Networks: A General Graph Neural Network Framework for Link Prediction" paper.

BatchNormalization(num_features[, eps, ...])

Applies batch normalization over a batch of features as described in the "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" paper.

InstanceNormalization(num_features[, eps, ...])

Applies instance normalization over each individual example in a batch of node features as described in the "Instance Normalization: The Missing Ingredient for Fast Stylization" paper.

LayerNormalization(num_features[, eps, ...])

Applies layer normalization over each individual example in a batch of features as described in the "Layer Normalization" paper.
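As a sketch of what normalizing "each individual example" means for layer normalization, the plain-Python function below standardizes every feature row with its own mean and variance. The name `layer_norm` is hypothetical, and the learnable scale and shift that such layers usually carry are omitted for brevity.

```python
import math


def layer_norm(features, eps=1e-5):
    """Normalize each row (one example) to zero mean and unit variance."""
    normalized = []
    for row in features:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        scale = 1.0 / math.sqrt(var + eps)  # eps guards against division by zero
        normalized.append([(v - mean) * scale for v in row])
    return normalized


batch = [[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 20.0, 20.0]]
normalized = layer_norm(batch)
```

Batch normalization differs in that the statistics are computed per feature across the examples in the batch, rather than per example across its features.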

global_add_pool(node_features[, batch_indices])

Sums all node features to obtain a global graph-level representation.

global_max_pool(node_features[, batch_indices])

Takes the feature-wise maximum over all node features to obtain a global graph-level representation.

global_mean_pool(node_features[, batch_indices])

Takes the feature-wise mean over all node features to obtain a global graph-level representation.
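The three pooling functions above differ only in the reduction they apply. The sketch below shows the idea in plain Python with a single hypothetical function, `global_pool`. It assumes that `batch_indices` assigns each node to a graph in the batch, that omitting it means all nodes belong to one graph, and that every graph has at least one node; the library's actual behavior may differ.

```python
def global_pool(node_features, batch_indices=None, mode="add"):
    """Reduce node features to one feature vector per graph.

    node_features: list of per-node feature lists.
    batch_indices: graph id of each node (assumption: None means one graph).
    mode: "add", "max" or "mean".
    """
    if batch_indices is None:
        batch_indices = [0] * len(node_features)

    # Group the node feature rows by the graph they belong to.
    num_graphs = max(batch_indices) + 1
    groups = [[] for _ in range(num_graphs)]
    for row, graph_id in zip(node_features, batch_indices):
        groups[graph_id].append(row)

    pooled = []
    for rows in groups:
        columns = list(zip(*rows))  # one tuple per feature dimension
        if mode == "add":
            pooled.append([float(sum(c)) for c in columns])
        elif mode == "max":
            pooled.append([float(max(c)) for c in columns])
        elif mode == "mean":
            pooled.append([sum(c) / len(c) for c in columns])
        else:
            raise ValueError(f"unknown pooling mode: {mode}")
    return pooled


features = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
batch_indices = [0, 0, 1]  # nodes 0 and 1 form graph 0, node 2 forms graph 1
summed = global_pool(features, batch_indices, mode="add")
```

Sum pooling keeps information about graph size, mean pooling removes it, and max pooling keeps only the strongest activation per feature.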