Source code for mlx_graphs.nn.linear

import mlx.core as mx
import mlx.nn as nn


class Linear(nn.Linear):
    """Linear layer whose parameters use Xavier Glorot uniform initialization.

    This class subclasses `mx.nn.Linear` and only overrides construction:
    after the parent initializer has run, the weight (and the bias, when
    enabled) are replaced with values drawn from
    ``nn.init.glorot_uniform()`` instead of mlx's default initialization.
    Everything else is inherited from the parent class.

    Args:
        input_dims: Dimensionality of the input features
        output_dims: Dimensionality of the output features
        bias: If set to ``False`` then the layer will not use a bias.
            Default is ``True``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
    ) -> None:
        super().__init__(input_dims, output_dims, bias)

        initializer = nn.init.glorot_uniform()

        # Overwrite the parent's weight with a Glorot-initialized matrix of
        # shape (output_dims, input_dims).
        weight_shape = (output_dims, input_dims)
        self.weight = initializer(mx.zeros(weight_shape))

        if not bias:
            return

        # The bias is drawn as an (output_dims, 1) column and then flattened
        # to a 1-D vector of length output_dims.
        # NOTE(review): presumably the column shape is used because the
        # initializer needs a 2-D input to derive fan-in/fan-out - confirm
        # against the mlx documentation.
        bias_column = initializer(mx.zeros((output_dims, 1)))
        self.bias = bias_column.flatten()