@@ -688,7 +688,13 @@ def __init__(
688688 dtype = dtype ,
689689 param_dtype = weights_dtype ,
690690 precision = precision ,
691- bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" ,)),
691+ bias_init = nnx .with_partitioning (
692+ nnx .initializers .zeros ,
693+ (
694+ None ,
695+ "embed" ,
696+ ),
697+ ),
692698 )
693699
694700 self .key = nnx .Linear (
@@ -699,7 +705,13 @@ def __init__(
699705 dtype = dtype ,
700706 param_dtype = weights_dtype ,
701707 precision = precision ,
702- bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" ,)),
708+ bias_init = nnx .with_partitioning (
709+ nnx .initializers .zeros ,
710+ (
711+ None ,
712+ "embed" ,
713+ ),
714+ ),
703715 )
704716
705717 self .value = nnx .Linear (
@@ -710,7 +722,13 @@ def __init__(
710722 dtype = dtype ,
711723 param_dtype = weights_dtype ,
712724 precision = precision ,
713- bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" ,)),
725+ bias_init = nnx .with_partitioning (
726+ nnx .initializers .zeros ,
727+ (
728+ None ,
729+ "embed" ,
730+ ),
731+ ),
714732 )
715733
716734 self .proj_attn = nnx .Linear (
@@ -731,15 +749,27 @@ def __init__(
731749 rngs = rngs ,
732750 epsilon = eps ,
733751 dtype = dtype ,
734- scale_init = nnx .with_partitioning (nnx .initializers .ones , (None , "norm" ,)),
752+ scale_init = nnx .with_partitioning (
753+ nnx .initializers .ones ,
754+ (
755+ None ,
756+ "norm" ,
757+ ),
758+ ),
735759 param_dtype = weights_dtype ,
736760 )
737761
738762 self .norm_k = nnx .RMSNorm (
739763 num_features = self .inner_dim ,
740764 rngs = rngs ,
741765 dtype = dtype ,
742- scale_init = nnx .with_partitioning (nnx .initializers .ones , (None , "norm" ,)),
766+ scale_init = nnx .with_partitioning (
767+ nnx .initializers .ones ,
768+ (
769+ None ,
770+ "norm" ,
771+ ),
772+ ),
743773 param_dtype = weights_dtype ,
744774 )
745775
0 commit comments