@@ -359,13 +359,13 @@ def __init__(
359359 # 1. Define Partitioned Initializers (Logical Axes)
360360 # Q, K, V kernels: [in_features (embed), out_features (heads)]
361361 qkv_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("embed" , "heads" ))
362- # Q, K, V biases: [out_features (embed )]
363- qkv_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), ("embed " ,))
362+ # Q, K, V biases: [out_features (heads )]
363+ qkv_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), ("heads " ,))
364364
365365 # Out kernel: [in_features (heads), out_features (embed)]
366366 out_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ("heads" , "embed" ))
367- # Out bias: [out_features (heads )]
368- out_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), ("heads " ,))
367+ # Out bias: [out_features (embed )]
368+ out_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), ("embed " ,))
369369
370370 # Norm scales
371371 norm_scale_init = nnx .with_partitioning (nnx .initializers .ones_init (), ("norm" ,))
0 commit comments