@@ -674,8 +674,10 @@ def __init__(
674674 dtype = dtype ,
675675 quant = quant ,
676676 )
677-
678- kernel_axes = ("embed" , "heads" )
677+ # The leading None axis corresponds to the dimension along which weights are
678+ # stacked across all blocks; it is introduced by nnx.vmap and nnx.scan.
679+ # Dims are [num_blocks, embed, heads].
680+ kernel_axes = (None , "embed" , "heads" )
679681 qkv_init_kernel = nnx .with_partitioning (nnx .initializers .lecun_normal (), kernel_axes )
680682
681683 self .query = nnx .Linear (
@@ -686,7 +688,7 @@ def __init__(
686688 dtype = dtype ,
687689 param_dtype = weights_dtype ,
688690 precision = precision ,
689- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed" ,)),
691+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" ,)),
690692 )
691693
692694 self .key = nnx .Linear (
@@ -697,7 +699,7 @@ def __init__(
697699 dtype = dtype ,
698700 param_dtype = weights_dtype ,
699701 precision = precision ,
700- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed" ,)),
702+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" ,)),
701703 )
702704
703705 self .value = nnx .Linear (
@@ -708,14 +710,14 @@ def __init__(
708710 dtype = dtype ,
709711 param_dtype = weights_dtype ,
710712 precision = precision ,
711- bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed" ,)),
713+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" ,)),
712714 )
713715
714716 self .proj_attn = nnx .Linear (
715717 rngs = rngs ,
716718 in_features = self .inner_dim ,
717719 out_features = self .inner_dim ,
718- kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("heads" , "embed" )),
720+ kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), (None , "heads" , "embed" )),
719721 dtype = dtype ,
720722 param_dtype = weights_dtype ,
721723 precision = precision ,
@@ -729,15 +731,15 @@ def __init__(
729731 rngs = rngs ,
730732 epsilon = eps ,
731733 dtype = dtype ,
732- scale_init = nnx .with_partitioning (nnx .initializers .ones , ("norm" ,)),
734+ scale_init = nnx .with_partitioning (nnx .initializers .ones , (None , "norm" ,)),
733735 param_dtype = weights_dtype ,
734736 )
735737
736738 self .norm_k = nnx .RMSNorm (
737739 num_features = self .inner_dim ,
738740 rngs = rngs ,
739741 dtype = dtype ,
740- scale_init = nnx .with_partitioning (nnx .initializers .ones , ("norm" ,)),
742+ scale_init = nnx .with_partitioning (nnx .initializers .ones , (None , "norm" ,)),
741743 param_dtype = weights_dtype ,
742744 )
743745
0 commit comments