@@ -888,14 +888,26 @@ def __init__(
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nnx.Linear(
                 self.added_kv_proj_dim, self.inner_dim, rngs=rngs,
-                dtype=dtype, param_dtype=weights_dtype, precision=precision
+                dtype=dtype, param_dtype=weights_dtype, precision=precision,
+                bias_init=nnx.with_partitioning(
+                    nnx.initializers.zeros,
+                    ("embed",),
+                ),
             )
             self.add_v_proj = nnx.Linear(
                 self.added_kv_proj_dim, self.inner_dim, rngs=rngs,
-                dtype=dtype, param_dtype=weights_dtype, precision=precision
+                dtype=dtype, param_dtype=weights_dtype, precision=precision,
+                bias_init=nnx.with_partitioning(
+                    nnx.initializers.zeros,
+                    ("embed",),
+                ),
             )
             self.norm_added_k = nnx.RMSNorm(
-                num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype
+                num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype,
+                scale_init=nnx.with_partitioning(
+                    nnx.initializers.ones,
+                    ("norm",),
+                ),
             )

     def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
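For context (not part of the commit): a minimal, self-contained sketch of the pattern this diff applies. `nnx.with_partitioning` wraps an initializer so the parameter it creates carries axis-name sharding metadata; `nnx.get_partition_spec` can later turn that metadata into `PartitionSpec`s when laying the model out over a device mesh. The feature sizes below are hypothetical and only the `bias_init` wiring mirrors the diff.

```python
# Minimal sketch of the partitioning pattern used in this commit.
# Sizes are hypothetical; only the bias_init wiring mirrors the diff.
from flax import nnx

rngs = nnx.Rngs(0)

layer = nnx.Linear(
    in_features=128, out_features=256, rngs=rngs,
    # Annotate the bias with the ("embed",) logical axis, as the
    # diff does for add_k_proj / add_v_proj.
    bias_init=nnx.with_partitioning(
        nnx.initializers.zeros,
        ("embed",),
    ),
)

# The annotation is recoverable as PartitionSpecs from the module
# state, which is what sharding utilities consume at jit time.
print(nnx.get_partition_spec(nnx.state(layer)))
```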