@@ -676,9 +676,29 @@ class NNXSimpleFeedForward(nnx.Module):
def __init__(
    self,
    rngs: nnx.Rngs,
    dim: int,
    dim_out: Optional[int] = None,
    mult: int = 4,
    activation_fn: str = "gelu",
    dtype: jnp.dtype = jnp.float32,
    weights_dtype: jnp.dtype = jnp.float32,
    precision: Optional[jax.lax.Precision] = None,
):
    """Build a two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        rngs: RNG container handed to both ``nnx.Linear`` layers.
        dim: Input feature size of the first linear layer.
        dim_out: Output feature size of the second linear layer. Falls back
            to ``dim`` when ``None``.
        mult: Multiplier applied to ``dim`` to get the hidden width
            (``int(dim * mult)``).
        activation_fn: Name passed to ``get_activation`` to pick the
            activation stored on ``self.act``.
        dtype: Forwarded to both linear layers as ``dtype``.
        weights_dtype: Forwarded to both linear layers as ``param_dtype``.
        precision: Forwarded to both linear layers as ``precision``.
    """
    inner_dim = int(dim * mult)
    dim_out = dim_out if dim_out is not None else dim

    # First projection: dim -> inner_dim. Kernel and bias initializers are
    # wrapped with nnx.with_partitioning so the parameters carry sharding
    # axis names.
    # NOTE(review): both kernel axes are labelled "embed" here, while net_2
    # below uses ("embed", "mlp") -- confirm these names against the
    # project's logical-axis rules.
    self.net_0 = nnx.Linear(
        dim,
        inner_dim,
        rngs=rngs,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "embed")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
    )

    self.act = get_activation(activation_fn)

    # Second projection: inner_dim -> dim_out, with the same dtype and
    # precision settings as net_0.
    self.net_2 = nnx.Linear(
        inner_dim,
        dim_out,
        rngs=rngs,
        use_bias=True,
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")),
        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
    )
682702
683703 def __call__ (self , hidden_states : Array ) -> Array :
684704 hidden_states = self .net_0 (hidden_states )
0 commit comments