Source code for mlx_graphs.nn.norm.batch_norm

import mlx.core as mx
import mlx.nn as nn


class BatchNormalization(nn.Module):
    r"""Applies batch normalization over a batch of features as described in
    the `"Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift" <https://arxiv.org/abs/1502.03167>`_
    paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    nodes inside the mini-batch.

    Args:
        num_features : Size of each input sample.
        eps : A value added to the denominator for numerical stability.
            (default: :obj:`1e-5`)
        momentum : The value used for the running mean and running variance
            computation. (default: :obj:`0.1`)
        affine : If set to :obj:`True`, this module has learnable affine
            parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        track_running_stats : If set to :obj:`True`, this module tracks the
            running mean and variance, and when set to :obj:`False`, this
            module does not track such statistics and always uses batch
            statistics in both training and eval modes.
            (default: :obj:`True`)
        allow_single_element : If set to :obj:`True`, batches with at most
            one element are normalized with the running mean and variance,
            as is done in evaluation mode. Requires
            :obj:`track_running_stats=True`. (default: :obj:`False`)

    Raises:
        ValueError: If ``allow_single_element`` is :obj:`True` while
            ``track_running_stats`` is :obj:`False`.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        allow_single_element: bool = False,
    ) -> None:
        super().__init__()

        # The single-element path reads the running statistics, so they
        # must be tracked by the wrapped module.
        if allow_single_element and not track_running_stats:
            raise ValueError(
                "'allow_single_element' requires "
                "'track_running_stats' to be set to `True`"
            )

        self.module = nn.BatchNorm(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        self.num_features = num_features
        self.allow_single_element = allow_single_element

    def __call__(self, x: mx.array) -> mx.array:
        """Normalize ``x``.

        Args:
            x : Input features; the first axis is treated as the batch axis.

        Returns:
            The normalized features. When ``allow_single_element`` is set
            and the first axis of ``x`` has at most one entry, the running
            statistics are used instead of batch statistics; otherwise the
            call is delegated to the wrapped ``nn.BatchNorm``.
        """
        if self.allow_single_element and x.shape[0] <= 1:
            norm = self.module
            x = (x - norm.running_mean) * mx.rsqrt(norm.running_var + norm.eps)
            # `weight`/`bias` only exist on the wrapped module when it was
            # built with affine=True; skip the affine step otherwise instead
            # of failing on the missing attributes.
            if "weight" in norm:
                x = norm.weight * x + norm.bias
            return x

        return self.module(x)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.module.num_features})"