Source code for mlx_graphs.transforms.normalize_features
from typing import List, Union
import mlx.core as mx
from mlx_graphs.data import GraphData, GraphDataBatch
from mlx_graphs.transforms import BaseTransform
[docs]
class FeaturesNormalizedTransform(BaseTransform):
    """Normalizes selected feature arrays of a graph along their last axis.

    Args:
        attributes: Names of the graph attributes to normalize. Defaults to
            ``["node_features"]``.
    """

    def __init__(self, attributes: Union[List[str], None] = None):
        # A ``None`` sentinel replaces the former mutable list default: a list
        # default is a single object shared by every instance. The argument is
        # also copied so that later mutation of the caller's list cannot
        # silently change this transform's configuration.
        self.attributes = (
            ["node_features"] if attributes is None else list(attributes)
        )

    def process(
        self, data: Union[GraphData, GraphDataBatch]
    ) -> Union[GraphData, GraphDataBatch]:
        """Normalizes each configured attribute so its last-axis slices sum to one.

        For every name in ``self.attributes`` whose value on ``data`` is a
        non-empty array, the array is first shifted by its global minimum (the
        minimum over all entries), so every entry becomes non-negative. Each
        slice along the last axis (each row, for 2-D features) is then divided
        by its sum. Sums smaller than 1, including all-zero rows, are clipped
        to 1. This avoids division by zero, but it means such rows are not
        rescaled to sum to exactly one.

        Attributes whose value is ``None`` or an empty array are left
        untouched. Inputs that are not ``GraphData`` instances are returned
        unchanged.

        Args:
            data: A graph object holding the feature arrays to normalize. It is
                modified in place (via ``setattr``) and also returned.

        Returns:
            The same ``data`` object, with the selected attributes replaced by
            their normalized versions.

        Raises:
            AttributeError: Normally raised by ``getattr`` if a configured
                attribute name does not exist on ``data``.

        Example:

        .. code-block:: python

            from mlx_graphs.datasets import EllipticBitcoinDataset
            from mlx_graphs.transforms import FeaturesNormalizedTransform

            dataset = EllipticBitcoinDataset(transform=FeaturesNormalizedTransform())
            # All the node features (if present) for this graph are normalized
        """
        if not isinstance(data, GraphData):
            return data

        for attribute in self.attributes:
            array = getattr(data, attribute)
            # Missing (None) or empty features have nothing to normalize.
            if array is None or array.size == 0:
                continue
            # Shift by the global minimum so every entry is non-negative.
            shifted = array - mx.min(array)
            row_sums = mx.sum(shifted, axis=-1, keepdims=True)
            # Clip sums to at least 1 so all-zero rows do not divide by zero.
            setattr(data, attribute, shifted / mx.clip(row_sums, 1.0, None))

        return data