@@ -708,9 +708,29 @@ class NNXSimpleFeedForward(nnx.Module):
def __init__(
    self,
    rngs: nnx.Rngs,
    dim: int,
    dim_out: Optional[int] = None,
    mult: int = 4,
    activation_fn: str = "gelu",
    dtype: jnp.dtype = jnp.float32,
    weights_dtype: jnp.dtype = jnp.float32,
    precision: Optional[jax.lax.Precision] = None,
):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
      rngs: NNX RNG container used to initialize the Linear parameters.
      dim: Input feature dimension.
      dim_out: Output feature dimension; defaults to ``dim`` when ``None``.
      mult: Expansion factor for the hidden layer (``inner_dim = dim * mult``).
      activation_fn: Name passed to ``get_activation`` (e.g. ``"gelu"``).
      dtype: Computation dtype for the Linear layers.
      weights_dtype: Parameter (storage) dtype for the Linear layers.
      precision: Optional matmul precision for the Linear layers.
    """
    inner_dim = int(dim * mult)
    # Fall back to a dim-preserving block when no explicit output size is given.
    dim_out = dim_out if dim_out is not None else dim
    # Up-projection: dim ("embed" axis) -> inner_dim (hidden "mlp" axis).
    # FIX: the kernel's output axis is the hidden dimension, so it must be
    # sharded on "mlp", not "embed" (and likewise the bias, which lives on the
    # output axis). The previous ("embed", "embed") / ("embed",) annotation
    # mapped inner_dim onto the embed mesh axis, which is inconsistent with the
    # standard embed->mlp->embed FFN sharding scheme.
    # NOTE(review): assumes the mesh defines "embed" and "mlp" axes as in
    # MaxText-style configs — confirm against the project's mesh definition.
    self.net_0 = nnx.Linear(
        dim,
        inner_dim,
        rngs=rngs,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    )
    self.act = get_activation(activation_fn)
    # Down-projection: inner_dim ("mlp" axis) -> dim_out ("embed" axis).
    # FIX: input axis is the hidden "mlp" dimension and the output axis is the
    # model "embed" dimension, so the kernel is ("mlp", "embed") and the bias
    # ("embed",) — the previous ("embed", "mlp") / ("mlp",) had them swapped.
    self.net_2 = nnx.Linear(
        inner_dim,
        dim_out,
        rngs=rngs,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )
714734
715735 def __call__ (self , hidden_states : Array ) -> Array :
716736 hidden_states = self .net_0 (hidden_states )
0 commit comments