mlx_graphs.nn.InstanceNormalization
- class mlx_graphs.nn.InstanceNormalization(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = False)
Bases: Module
Applies instance normalization over each individual example (graph) in a batch of node features, as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]
The mean and standard deviation are computed per feature dimension, separately for each graph in a mini-batch; \(\gamma\) and \(\beta\) are the learnable affine parameters.
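The formula above can be sketched in plain Python as follows. This is a minimal illustration, not the library's implementation: it assumes node features are given as a list of per-node vectors, that `batch` lists the graph index of each node, and that the variance is the biased (divide-by-n) estimate, as in PyTorch-style instance normalization. The helper name `instance_norm` is hypothetical.

```python
from math import sqrt


def instance_norm(x, batch, eps=1e-5, gamma=None, beta=None):
    """Hypothetical helper: normalize node features per graph and per feature.

    x:     list of N node feature vectors, each a list of F floats
    batch: list of N graph indices; batch[i] is the graph of node i
    gamma, beta: optional per-feature scale and shift (length F)
    """
    num_features = len(x[0])
    if gamma is None:
        gamma = [1.0] * num_features
    if beta is None:
        beta = [0.0] * num_features

    # Group node indices by the graph they belong to.
    groups = {}
    for i, g in enumerate(batch):
        groups.setdefault(g, []).append(i)

    out = [[0.0] * num_features for _ in x]
    for nodes in groups.values():
        n = len(nodes)
        for d in range(num_features):
            values = [x[i][d] for i in nodes]
            mean = sum(values) / n
            # Biased variance (divide by n); this choice is an assumption.
            var = sum((v - mean) ** 2 for v in values) / n
            denom = sqrt(var + eps)
            for i in nodes:
                out[i][d] = (x[i][d] - mean) / denom * gamma[d] + beta[d]
    return out


# Two graphs with two nodes each: nodes 0-1 belong to graph 0, nodes 2-3 to graph 1.
x = [[1.0, 10.0], [3.0, 30.0], [5.0, -2.0], [7.0, 2.0]]
batch = [0, 0, 1, 1]
out = instance_norm(x, batch)
```

Each graph's features end up with zero mean and (up to \(\epsilon\)) unit variance per dimension, independently of the other graph in the batch; here every entry of `out` is close to -1 or 1.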
- Parameters:
num_features (int) – Size of each input sample, i.e. the number of node features.
eps (float) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float) – The value used for the running mean and running variance computation; only relevant when track_running_stats is True. (default: 0.1)
affine (bool) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
track_running_stats (bool) – If set to True, this module tracks the running mean and variance; when set to False, it does not track such statistics and always uses instance statistics in both training and eval modes. (default: False)
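When running statistics are tracked, they are typically maintained as an exponential moving average controlled by momentum. The sketch below assumes the PyTorch-style update \(r \leftarrow (1 - m)\,r + m\,s\) and assumes the per-graph statistics are averaged over the graphs in the mini-batch before the update; neither convention is confirmed by this page, and the helper name `update_running_stats` is hypothetical.

```python
def update_running_stats(running_mean, running_var, graph_means, graph_vars, momentum=0.1):
    """Hypothetical helper: one exponential-moving-average update of running stats.

    running_mean, running_var: per-feature running statistics (length F)
    graph_means, graph_vars:   per-graph statistics, one length-F list per graph
    """
    num_graphs = len(graph_means)
    # Average the per-graph statistics over the mini-batch (assumption).
    batch_mean = [sum(col) / num_graphs for col in zip(*graph_means)]
    batch_var = [sum(col) / num_graphs for col in zip(*graph_vars)]
    # PyTorch-style update: new = (1 - momentum) * old + momentum * observed (assumption).
    new_mean = [(1 - momentum) * r + momentum * b for r, b in zip(running_mean, batch_mean)]
    new_var = [(1 - momentum) * r + momentum * b for r, b in zip(running_var, batch_var)]
    return new_mean, new_var


# Per-graph statistics for a hypothetical batch of two graphs with two features each.
mean, var = update_running_stats(
    running_mean=[0.0, 0.0],
    running_var=[1.0, 1.0],
    graph_means=[[2.0, 20.0], [6.0, 0.0]],
    graph_vars=[[1.0, 100.0], [1.0, 4.0]],
)
```

With momentum = 0.1, each update moves the running statistics 10% of the way toward the batch statistics, so `mean` becomes approximately [0.4, 1.0] and `var` approximately [1.0, 6.1].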
- __call__(node_features: mlx.core.array, batch: mlx.core.array | None = None, batch_size: int | None = None)
Call self as a function.
Methods