from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.utils import degree, scatter
class InstanceNormalization(nn.Module):
    r"""Instance normalization over each individual example in a batch of node
    features, as described in the paper `Instance Normalization: The Missing
    Ingredient for Fast Stylization <https://arxiv.org/abs/1607.08022>`_.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately for
    each object in a mini-batch.

    Args:
        num_features : Size of each input sample.
        eps : A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        momentum : The value used for the running mean and
            running variance computation. Only read when
            ``track_running_stats`` is :obj:`True`. (default: :obj:`0.1`)
        affine : If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        track_running_stats : If set to :obj:`True`, this
            module tracks the running mean and variance, and when set to
            :obj:`False`, this module does not track such statistics and always
            uses instance statistics in both training and eval modes.
            (default: :obj:`False`)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = False,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Used only when `__call__` receives no `batch` vector.
        self.instance_norm = nn.InstanceNorm(num_features, self.eps, self.affine)

        if affine:
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))

        if self.track_running_stats:
            self.running_mean = mx.zeros((num_features,))
            self.running_var = mx.ones((num_features,))
            # Running statistics are updated by assignment in `__call__`, not by
            # gradient descent, so they are excluded from the trainable set.
            self.freeze(keys=["running_mean", "running_var"], recurse=False)

    def __call__(
        self,
        node_features: mx.array,
        batch: Optional[mx.array] = None,
        batch_size: Optional[int] = None,
    ) -> mx.array:
        """Normalize ``node_features`` separately for every graph in the batch.

        Args:
            node_features: Node feature matrix; the code reduces over axis 0, so
                it is expected to be laid out as ``(num_nodes, num_features)``.
            batch: For each node, the index of the graph it belongs to. When
                :obj:`None`, all nodes are treated as a single graph.
            batch_size: Number of graphs in the batch. It must be greater than
                every index in ``batch``. When :obj:`None` it is computed as
                ``batch.max() + 1``, which forces a device-to-host
                synchronization; pass it explicitly to avoid that.

        Returns:
            The normalized node features, with the same layout as the input.
        """
        if batch is None:
            # NOTE(review): this path delegates to `self.instance_norm`, so it
            # neither reads nor updates `running_mean`/`running_var`, and it
            # does not apply `self.weight`/`self.bias`. With `affine=True` the
            # inner `nn.InstanceNorm` presumably owns a separate pair of affine
            # parameters -- confirm whether the two sets should be shared.
            return self.instance_norm(mx.expand_dims(node_features, axis=0)).squeeze(
                axis=0
            )

        if self.training or not self.track_running_stats:
            # Resolve the number of graphs once and reuse it for every
            # reduction, so a caller-supplied value is honoured and `.item()`
            # is evaluated at most once.
            if batch_size is None:
                batch_size = batch.max().item() + 1

            # Nodes per graph, clamped to >= 1 so that graphs without nodes do
            # not cause a division by zero.
            counts = mx.clip(degree(batch, batch_size), a_min=1, a_max=None)
            counts = counts.reshape(-1, 1)

            mean = (
                scatter(node_features, batch, batch_size, aggr="add", axis=0) / counts
            )
            centered = node_features - mx.take(mean, batch, axis=0)
            squared_sum = scatter(
                centered * centered, batch, batch_size, aggr="add", axis=0
            )
            # The biased (population) variance is used for the normalization.
            var = squared_sum / counts

            if self.track_running_stats:
                # The running variance accumulates the unbiased estimate
                # (divide by n - 1, clamped to >= 1 for single-node graphs).
                unbiased_counts = mx.clip(counts - 1, a_min=1, a_max=None)
                unbiased_var = squared_sum / unbiased_counts
                momentum = self.momentum
                self.running_mean = (
                    1 - momentum
                ) * self.running_mean + momentum * mean.mean(0)
                self.running_var = (
                    1 - momentum
                ) * self.running_var + momentum * unbiased_var.mean(0)

            out = centered / mx.take((var + self.eps).sqrt(), batch, axis=0)
        else:
            # Eval mode with tracked statistics: every graph shares the same
            # running mean/variance, so they broadcast over the node axis
            # without being tiled per graph and gathered back per node.
            out = (node_features - self.running_mean) / (
                self.running_var + self.eps
            ).sqrt()

        if self.affine:
            out = out * self.weight + self.bias
        return out