Source code for mlx_graphs.nn.norm.layer_norm
from typing import Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_graphs.utils import degree, scatter
class LayerNormalization(nn.Module):
r"""Applies layer normalization over each individual example in a batch
of features as described in the `Layer Normalization <https://arxiv.org/abs/1607.06450>`_
paper.
.. math::
\mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
\textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
\odot \gamma + \beta
The mean and standard deviation are calculated across all nodes and all
node channels separately for each graph in a mini-batch.
Args:
num_features : Size of each input sample.
eps : A value added to the denominator for numerical
stability. (default: :obj:`1e-5`)
affine : If set to :obj:`True`, this module has
learnable affine parameters :math:`\gamma` and :math:`\beta`.
(default: :obj:`True`)
mode : The normalization mode to use for layer
normalization (:obj:`"graph"` or :obj:`"node"`). If :obj:`"graph"`
is used, each graph will be considered as an element to be
normalized. If `"node"` is used, each node will be considered as
an element to be normalized. (default: :obj:`"graph"`)
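Example:
    A minimal usage sketch; the feature matrix, the ``batch`` assignment
    vector and the shapes below are illustrative only:

    .. code-block:: python

        import mlx.core as mx
        from mlx_graphs.nn.norm.layer_norm import LayerNormalization

        # two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
        features = mx.random.normal((5, 16))
        batch = mx.array([0, 0, 0, 1, 1])

        layer_norm = LayerNormalization(num_features=16, mode="graph")
        out = layer_norm(features, batch)  # same shape as ``features``, i.e. (5, 16)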
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
affine: bool = True,
mode: Literal["graph", "node"] = "graph",
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.mode = mode
if affine:
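# learnable per-channel affine parameters: gamma (scale) and beta (shift)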
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))
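# standard LayerNorm over the feature dimension, used when mode == "node"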
self.layernorm = nn.LayerNorm(
dims=self.num_features, eps=self.eps, affine=self.affine
)
def __call__(
self,
features: mx.array,
batch: Optional[mx.array] = None,
batch_size: Optional[int] = None,
) -> mx.array:
if self.mode == "graph":
# perform graph level normalization
if batch is None:
    # single graph: normalize over all nodes and channels at once
    features = features - features.mean()
    out = features / (features.var() + self.eps).sqrt()
else:
if batch_size is None:
batch_size = batch.max().item() + 1
batch_index = batch
# number of nodes in each graph, clamped to 1 to avoid division by zero
norm = mx.clip(degree(batch_index), a_min=1, a_max=None)
# total element count per graph (nodes * channels), reshaped to a column
# vector so it broadcasts against the per-graph sums below
norm = (norm * features.shape[-1]).reshape(-1, 1)
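# per-graph mean over all nodes and channels, shape (batch_size, 1)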
mean = (
scatter(
features,
batch_index,
batch_size,
aggr="add",
axis=0,
).sum(axis=-1, keepdims=True)
/ norm
)
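# center each node's features with the mean of its graph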
node_features = features - mx.take(mean, batch_index, axis=0)
variance = scatter(
node_features * node_features,
batch_index,
batch_size,
aggr="add",
axis=0,
).sum(axis=-1, keepdims=True)
variance = variance / norm
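# rescale by the per-graph standard deviation (eps added for numerical stability)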
out = node_features / mx.take(
(variance + self.eps).sqrt(), batch_index, axis=0
)
if self.affine and self.weight is not None and self.bias is not None:
out = out * self.weight + self.bias
return out
elif self.mode == "node":
return self.layernorm(features)
else:
raise ValueError(f"Unknow normalization mode: {self.mode}")