mlx_graphs.nn.InstanceNormalization
- class mlx_graphs.nn.InstanceNormalization(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = False)
Bases: Module

Instance normalization over each individual example in a batch of node features, as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard deviation are calculated per dimension, separately for each object (graph) in a mini-batch.
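To make the formula concrete, here is a minimal NumPy sketch of per-graph instance normalization. It is not the mlx_graphs implementation. It assumes node features are stacked row-wise into a `(num_nodes, num_features)` array and that a `batch` vector gives the graph index of each node. The biased variance is used, and `instance_norm_sketch` is a hypothetical name.

```python
import numpy as np

def instance_norm_sketch(x, batch, gamma=None, beta=None, eps=1e-5):
    """Hypothetical sketch: normalize each feature dimension within each graph.

    x:     (num_nodes, num_features) node features
    batch: (num_nodes,) graph index of each node (assumed convention)
    """
    x = np.asarray(x, dtype=np.float64)
    batch = np.asarray(batch)
    out = np.empty_like(x)
    for g in np.unique(batch):
        mask = batch == g
        xg = x[mask]
        mean = xg.mean(axis=0)  # E[x]: per feature, within this graph only
        var = xg.var(axis=0)    # Var[x]: biased variance, per feature, per graph
        out[mask] = (xg - mean) / np.sqrt(var + eps)
    if gamma is not None:
        out = out * np.asarray(gamma)  # elementwise scale (the "odot gamma" term)
    if beta is not None:
        out = out + np.asarray(beta)   # elementwise shift
    return out

# Two graphs: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
x = [[1.0, 2.0], [3.0, 6.0], [10.0, 0.0], [20.0, 0.0], [30.0, 0.0]]
batch = [0, 0, 1, 1, 1]
out = instance_norm_sketch(x, batch)
```

Because statistics are computed per graph, a feature that is constant within one graph (column 1 of graph 1 here) normalizes to zero there, regardless of its values in other graphs.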
- Parameters:
  num_features (int) – Size of each input sample.
  eps (float) – A value added to the denominator for numerical stability. (default: 1e-5)
  momentum (float) – The value used for the running mean and running variance computation. (default: 0.1)
  affine (bool) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
  track_running_stats (bool) – If set to True, this module tracks the running mean and variance. When set to False, it does not track such statistics and always uses instance statistics in both training and eval modes. (default: False)
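The role of `momentum` when `track_running_stats` is True can be illustrated with an exponential moving average. This page does not document the exact update rule. The sketch therefore assumes the PyTorch-style convention, where `momentum` weights the new observation, and assumes per-graph statistics are averaged over the mini-batch before the update. `update_running_stats` is a hypothetical helper.

```python
import numpy as np

def update_running_stats(running_mean, running_var, graph_means, graph_vars, momentum=0.1):
    """Hypothetical EMA update of running statistics.

    graph_means, graph_vars: (num_graphs, num_features) per-graph statistics
    computed on the current mini-batch.
    """
    # Assumption: per-graph statistics are averaged over the mini-batch first.
    batch_mean = np.mean(graph_means, axis=0)
    batch_var = np.mean(graph_vars, axis=0)
    # Assumption: PyTorch-style convention, `momentum` weights the new observation.
    new_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
    new_var = (1.0 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var

# Running statistics start at mean 0 and variance 1 for a 2-feature input.
running_mean = np.zeros(2)
running_var = np.ones(2)
graph_means = np.array([[8.0, 0.0], [12.0, 4.0]])
graph_vars = np.array([[2.0, 1.0], [2.0, 3.0]])
running_mean, running_var = update_running_stats(
    running_mean, running_var, graph_means, graph_vars, momentum=0.1
)
```

With `momentum=0.1`, each update moves the running statistics 10% of the way toward the current mini-batch statistics.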
- __call__(node_features: mlx.core.array, batch: mlx.core.array | None = None, batch_size: int | None = None)
Call self as a function: apply instance normalization to node_features, using batch to assign each node to its graph.
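The `batch` argument presumably follows the common graph-learning convention. Under that convention, node features of several graphs are stacked row-wise, and `batch[k]` holds the index of the graph that node `k` belongs to. `batch_size` would then be the number of graphs. The sketch below builds such a vector with NumPy. `make_batch_vector` is a hypothetical helper, not part of mlx_graphs.

```python
import numpy as np

def make_batch_vector(nodes_per_graph):
    """Hypothetical helper: graph index for every node when the node features
    of several graphs are stacked row-wise into one array."""
    counts = np.asarray(nodes_per_graph, dtype=np.int64)
    # Graph g contributes counts[g] consecutive entries equal to g.
    batch = np.repeat(np.arange(len(counts)), counts)
    batch_size = len(counts)  # number of graphs in the mini-batch
    return batch, batch_size

# A mini-batch with one 2-node graph followed by one 3-node graph.
batch, batch_size = make_batch_vector([2, 3])
```

Here `batch` is `[0, 0, 1, 1, 1]` and `batch_size` is 2. Both are assumed to be converted to the form `__call__` expects (for example an `mlx.core.array` for `batch`) before use.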