@@ -717,8 +717,8 @@ def __init__(
717717 dtype = dtype ,
718718 param_dtype = weights_dtype ,
719719 precision = precision ,
720- kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed" , None )),
721- bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None ,)),
720+ kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed" , "mlp" )),
721+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp" ,)),
722722 )
723723 self .act = get_activation (activation_fn )
724724 self .net_2 = nnx .Linear (
@@ -729,8 +729,8 @@ def __init__(
729729 dtype = dtype ,
730730 param_dtype = weights_dtype ,
731731 precision = precision ,
732- kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed" , "mlp " )),
733- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("mlp " ,)),
732+ kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("mlp" , "embed " )),
733+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed " ,)),
734734 )
735735
736736 def __call__ (self , hidden_states : Array ) -> Array :
0 commit comments