@@ -349,23 +349,37 @@ def __init__(
349349 rope_type : str = "interleaved" ,
350350 flash_block_sizes : BlockSizes = None ,
351351 flash_min_seq_length : int = 4096 ,
352+ qkv_sharding_spec : Optional [tuple ] = None ,
353+ out_sharding_spec : Optional [tuple ] = None ,
354+ out_bias_sharding_spec : Optional [tuple ] = None ,
352355 ):
353356 self .heads = heads
354357 self .rope_type = rope_type
355358 self .dim_head = dim_head
356359 self .inner_dim = dim_head * heads
357360 self .dropout_rate = dropout
358361
362+ # Auto-detect hardware for sharding specs if not overridden
363+ device_kind = jax .devices ()[0 ].device_kind
364+ is_ironwood = "7x" in device_kind
365+
366+ if qkv_sharding_spec is None :
367+ qkv_sharding_spec = (None , "heads" ) if is_ironwood else ("embed" , "heads" )
368+ if out_sharding_spec is None :
369+ out_sharding_spec = ("heads" , None ) if is_ironwood else ("heads" , "embed" )
370+ if out_bias_sharding_spec is None :
371+ out_bias_sharding_spec = (None ,) if is_ironwood else ("embed" ,)
372+
359373 # 1. Define Partitioned Initializers (Logical Axes)
360374 # Q, K, V kernels: [in_features (embed), out_features (heads)]
361- qkv_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ( None , "heads" ) )
375+ qkv_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), qkv_sharding_spec )
362376 # Q, K, V biases: [out_features (heads)]
363377 qkv_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), ("heads" ,))
364378
365379 # Out kernel: [in_features (heads), out_features (embed)]
366- out_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), ( "heads" , None ) )
380+ out_kernel_init = nnx .with_partitioning (nnx .initializers .lecun_normal (), out_sharding_spec )
367381 # Out bias: [out_features (embed)]
368- out_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), ( None ,) )
382+ out_bias_init = nnx .with_partitioning (nnx .initializers .zeros_init (), out_bias_sharding_spec )
369383
370384 # Norm scales
371385 norm_scale_init = nnx .with_partitioning (nnx .initializers .ones_init (), ("norm" ,))