mlx_graphs.nn.InstanceNormalization

class mlx_graphs.nn.InstanceNormalization(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = False)

Bases: Module

Instance normalization over each individual example in a batch of node features, as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization.

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard deviation are calculated per feature dimension, separately for each object in a mini-batch.
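The per-object, per-dimension computation can be illustrated with a plain-Python reference sketch. This is not the library's implementation: the helper name `instance_norm_reference` is hypothetical, `batch` is assumed to hold one object index per node, the variance is assumed to be the biased (population) estimate, and the affine step is left out (equivalent to \(\gamma = 1\), \(\beta = 0\)).

```python
import math
from statistics import fmean, pvariance


def instance_norm_reference(node_features, batch, eps=1e-5):
    """Normalize each feature column separately within each object.

    node_features: list of rows, one row of floats per node.
    batch: list of ints; batch[i] is the object index of node i (assumption).
    """
    num_features = len(node_features[0])

    # Group node (row) indices by the object they belong to.
    rows_by_object = {}
    for row, obj in enumerate(batch):
        rows_by_object.setdefault(obj, []).append(row)

    out = [[0.0] * num_features for _ in node_features]
    for rows in rows_by_object.values():
        for j in range(num_features):
            column = [node_features[r][j] for r in rows]
            mean = fmean(column)
            var = pvariance(column, mean)  # biased (population) variance
            scale = math.sqrt(var + eps)
            for r in rows:
                out[r][j] = (node_features[r][j] - mean) / scale
    return out


# Five nodes with two features each: nodes 0-1 form object 0, nodes 2-4 object 1.
x = [[1.0, 10.0], [3.0, 30.0], [0.0, 5.0], [4.0, 5.0], [8.0, 5.0]]
batch = [0, 0, 1, 1, 1]
y = instance_norm_reference(x, batch)
```

Each column of `y` has approximately zero mean within each object. A column that is constant within an object (the second feature of object 1 here) maps to zeros, because `eps` keeps the denominator non-zero.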

Parameters:
  • num_features (int) – Size of each input sample.

  • eps (float) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float) – The value used for the running mean and running variance computation. (default: 0.1)

  • affine (bool) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • track_running_stats (bool) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses instance statistics in both training and eval modes. (default: False)
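The role of `momentum` can be sketched as an exponential moving average. The update rule below assumes the common PyTorch-style convention, in which `momentum` is the weight given to the newly observed statistic; this page does not state the exact rule, so treat it as an illustration rather than the module's confirmed behavior. The helper name `update_running_stat` is hypothetical.

```python
def update_running_stat(running, observed, momentum=0.1):
    """Blend a newly observed statistic into the running estimate.

    Assumed convention: new = (1 - momentum) * running + momentum * observed.
    """
    return (1.0 - momentum) * running + momentum * observed


# Starting from 0.0, the running mean drifts toward the observed mean of 2.0.
running_mean = 0.0
for observed_mean in [2.0, 2.0, 2.0]:
    running_mean = update_running_stat(running_mean, observed_mean)
```

Under this assumption, a smaller `momentum` makes the running estimate change more slowly, which smooths out noise from individual mini-batches.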

__call__(node_features: mlx.core.array, batch: mlx.core.array | None = None, batch_size: int | None = None)

Call self as a function.

Methods