@@ -342,12 +342,24 @@ def __init__(
342342 self .dropout_rate = dropout
343343
344344 # 1. Projections
345- self .to_q = nnx .Linear (query_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype )
345+ self .to_q = nnx .Linear (
346+ query_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype ,
347+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("embed" , "heads" )),
348+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("heads" ,)),
349+ )
346350
347351 # Handle Self vs Cross Attention input dims
348352 kv_dim = context_dim if context_dim is not None else query_dim
349- self .to_k = nnx .Linear (kv_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype )
350- self .to_v = nnx .Linear (kv_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype )
353+ self .to_k = nnx .Linear (
354+ kv_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype ,
355+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("embed" , "heads" )),
356+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("heads" ,)),
357+ )
358+ self .to_v = nnx .Linear (
359+ kv_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype ,
360+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("embed" , "heads" )),
361+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("heads" ,)),
362+ )
351363
352364 # 2. Normalization (Applied to full inner_dim, NOT per-head)
353365 self .norm_q = nnx .RMSNorm (
@@ -358,7 +370,11 @@ def __init__(
358370 )
359371
360372 # 3. Output
361- self .to_out = nnx .Linear (self .inner_dim , query_dim , use_bias = out_bias , rngs = rngs , dtype = dtype )
373+ self .to_out = nnx .Linear (
374+ self .inner_dim , query_dim , use_bias = out_bias , rngs = rngs , dtype = dtype ,
375+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), ("heads" , "embed" )),
376+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , ("embed" ,)),
377+ )
362378
363379 if self .dropout_rate > 0 :
364380 self .dropout_layer = nnx .Dropout (self .dropout_rate , rngs = rngs )
0 commit comments