@@ -734,7 +734,7 @@ def __init__(
         # None axes corresponds to the stacked weights across all blocks
         # because of the use of nnx.vmap and nnx.scan.
         # Dims are [num_blocks, embed, heads]
-        kernel_axes = ("embed", "heads")
+        kernel_axes = (None, "qkv")
         qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)

         self.query = nnx.Linear(
@@ -747,7 +747,7 @@ def __init__(
             precision=precision,
             bias_init=nnx.with_partitioning(
                 nnx.initializers.zeros,
-                ("embed",),
+                ("qkv",),
             ),
         )

@@ -761,7 +761,7 @@ def __init__(
             precision=precision,
             bias_init=nnx.with_partitioning(
                 nnx.initializers.zeros,
-                ("embed",),
+                ("qkv",),
             ),
         )

@@ -775,22 +775,19 @@ def __init__(
             precision=precision,
             bias_init=nnx.with_partitioning(
                 nnx.initializers.zeros,
-                ("embed",),
+                ("qkv",),
             ),
         )

         self.proj_attn = nnx.Linear(
             rngs=rngs,
             in_features=self.inner_dim,
             out_features=self.inner_dim,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
+            kernel_init=nnx.with_partitioning(
+                nnx.initializers.lecun_normal(), ("proj_out", None)),
             dtype=dtype,
             param_dtype=weights_dtype,
             precision=precision,
-            bias_init=nnx.with_partitioning(
-                nnx.initializers.zeros,
-                ("heads",),
-            ),
         )

         self.drop_out = nnx.Dropout(dropout)
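For reference, here is a minimal runnable sketch (not part of this commit; the layer sizes and RNG seed are invented) of what the nnx.with_partitioning annotations above do: they attach logical axis names to the initialized parameters as metadata, which nnx.get_partition_spec can later materialize as jax.sharding.PartitionSpec entries. The "qkv" and "proj_out" logical names are assumed to be mapped to physical mesh axes elsewhere in the repo.

from flax import nnx

# Tag kernel and bias with logical axis names; None leaves that
# dimension unannotated (here, the block-stacked dimension).
layer = nnx.Linear(
    in_features=128,
    out_features=384,
    kernel_init=nnx.with_partitioning(
        nnx.initializers.lecun_normal(), (None, "qkv")
    ),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("qkv",)),
    rngs=nnx.Rngs(0),
)

# The names ride along on the params and can be turned into
# PartitionSpecs for the whole state tree:
specs = nnx.get_partition_spec(nnx.state(layer))
# kernel -> PartitionSpec(None, 'qkv'); bias -> PartitionSpec('qkv')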
@@ -803,21 +800,13 @@ def __init__(
             rngs=rngs,
             epsilon=eps,
             dtype=dtype,
-            scale_init=nnx.with_partitioning(
-                nnx.initializers.ones,
-                ("norm",),
-            ),
             param_dtype=weights_dtype,
         )

         self.norm_k = nnx.RMSNorm(
             num_features=self.inner_dim,
             rngs=rngs,
             dtype=dtype,
-            scale_init=nnx.with_partitioning(
-                nnx.initializers.ones,
-                ("norm",),
-            ),
             param_dtype=weights_dtype,
         )

@@ -845,8 +834,6 @@ def __call__(
         deterministic: bool = True,
         rngs: nnx.Rngs = None,
     ) -> jax.Array:
-        hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-        encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
         dtype = hidden_states.dtype
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
@@ -855,6 +842,11 @@ def __call__(
         key_proj = self.key(encoder_hidden_states)
         value_proj = self.value(encoder_hidden_states)

+        query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", ("tensor", "fsdp"), None))
+        key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", ("tensor", "fsdp"), None))
+        value_proj = jax.lax.with_sharding_constraint(value_proj, PartitionSpec("data", ("tensor", "fsdp"), None))
+
+
         if self.qk_norm:
             query_proj = self.norm_q(query_proj)
             key_proj = self.norm_k(key_proj)
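Similarly, a self-contained sketch of the relocated activation constraints (assumptions: exactly 4 devices for the toy 1x2x2 mesh, invented tensor shapes, and a hypothetical project_q stand-in for self.query; only the axis names and the PartitionSpec mirror the commit). The constraint now lands on the Q/K/V projection outputs rather than on the raw attention inputs, splitting batch on "data" and sequence across the combined ("tensor", "fsdp") axes while replicating features.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative mesh; axis names match the specs used in __call__.
mesh = Mesh(np.array(jax.devices()).reshape(1, 2, 2),
            ("data", "fsdp", "tensor"))

@jax.jit
def project_q(x, w_q):  # hypothetical stand-in for self.query(...)
    q = x @ w_q  # [batch, seq, embed] @ [embed, inner] -> [batch, seq, inner]
    # Constrain the projection output, as the commit now does for
    # query_proj / key_proj / value_proj.
    return jax.lax.with_sharding_constraint(
        q, NamedSharding(mesh, PartitionSpec("data", ("tensor", "fsdp"), None)))

x = jnp.ones((8, 256, 128))
w_q = jnp.ones((128, 384))
print(project_q(x, w_q).sharding)  # seq dim split 4 ways over tensor*fsdp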