mlx_graphs.nn.BatchNormalization

class mlx_graphs.nn.BatchNormalization(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, allow_single_element: bool = False)

Bases: Module

Applies batch normalization over a batch of features as described in the "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" paper.

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch.
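The normalization itself can be illustrated with a plain NumPy sketch (an illustration of the formula above, not the module's actual implementation; the learnable parameters gamma and beta are passed in explicitly here):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    # Per-dimension mean and variance over all nodes in the mini-batch (axis 0).
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    # Normalize, then apply the elementwise affine transform gamma/beta.
    x_hat = (x - mean) / np.sqrt(var + eps)
    return x_hat * gamma + beta

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
out = batch_norm(x, gamma=np.ones(2), beta=np.zeros(2))
# Each feature dimension of `out` has (approximately) zero mean and unit variance.
```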

Parameters:
  • num_features (int) – Size of each input sample.

  • eps (float) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)

  • allow_single_element (bool) – If set to True, batches with only a single element behave as in evaluation mode; that is, the running mean and variance are used. Requires track_running_stats=True. (default: False)
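The momentum parameter follows the usual exponential-moving-average convention for running statistics (a sketch of that convention, not necessarily this module's exact code):

```python
def update_running_stat(running, batch_stat, momentum=0.1):
    # Conventional batch-norm running-statistics update:
    # new = (1 - momentum) * old + momentum * current batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

running_mean = 0.0
for batch_mean in [2.0, 2.0, 2.0]:
    running_mean = update_running_stat(running_mean, batch_mean)
# After three batches with mean 2.0, the running mean has moved part of
# the way toward 2.0: 0.2 -> 0.38 -> 0.542
```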

__call__(x: mlx.core.array)

Apply batch normalization to the input node features x.
