mlx_graphs.nn.LayerNormalization#

class mlx_graphs.nn.LayerNormalization(num_features: int, eps: float = 1e-05, affine: bool = True, mode: Literal['graph', 'node'] = 'graph')[source]#

Bases: Module

Applies layer normalization over each individual example in a batch of features as described in the Layer Normalization paper.

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard deviation are calculated across all nodes and all node channels, separately for each graph in a mini-batch.

Parameters:
  • num_features (int) – Size of each input sample (number of feature channels).

  • eps (float) – A value added to the denominator for numerical stability. (default: 1e-5)

  • affine (bool) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • mode (Literal['graph', 'node']) – The normalization mode to use for layer normalization ("graph" or "node"). If "graph" is used, each graph will be considered as an element to be normalized. If "node" is used, each node will be considered as an element to be normalized. (default: "graph")
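The "graph" mode described above can be sketched in plain NumPy: for every graph in the mini-batch, a single mean and variance are computed over all of its nodes and channels, then the affine parameters \(\gamma\) and \(\beta\) are applied per channel. This is an illustrative re-implementation of the formula, not the library's actual code; the function name `graph_layer_norm` is hypothetical.

```python
import numpy as np

def graph_layer_norm(x, batch, gamma, beta, eps=1e-5):
    """Sketch of "graph"-mode layer normalization.

    x:     (num_nodes, num_features) node feature matrix
    batch: (num_nodes,) graph index of each node in the mini-batch
    gamma, beta: (num_features,) learnable affine parameters
    """
    out = np.empty_like(x)
    for g in np.unique(batch):
        mask = batch == g
        # Statistics are scalars: taken over all nodes AND all channels
        # of this graph, as stated in the class description.
        mean = x[mask].mean()
        var = x[mask].var()
        out[mask] = (x[mask] - mean) / np.sqrt(var + eps) * gamma + beta
    return out

# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
              [10.0, 20.0], [30.0, 40.0]])
batch = np.array([0, 0, 0, 1, 1])
gamma, beta = np.ones(2), np.zeros(2)
y = graph_layer_norm(x, batch, gamma, beta)
# Each graph's slice of y is now (approximately) zero-mean, unit-variance.
```

In "node" mode, by contrast, the statistics would be computed per node over its feature channels only, so each row of `x` would be normalized independently of the batch assignment.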

__call__(features: mlx.core.array, batch: mlx.core.array | None = None, batch_size: int | None = None) → Any[source]#

Apply layer normalization to the given node features. features is the node feature matrix; batch optionally assigns each node to a graph in the mini-batch (used in "graph" mode), and batch_size may be provided to avoid recomputing the number of graphs.

Return type:

Any
