@@ -1010,12 +1010,24 @@ def scan_fn(carry, block):
10101010 transform_metadata = {nnx .PARTITION_NAME : "layers" },
10111011 )(carry , self .transformer_blocks )
10121012 else :
1013- activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
1013+ mlp_rules = nn .logical_to_mesh_axes (("mlp" , "tensor" ))
1014+ tensor_rules = nn .logical_to_mesh_axes (("tensor" ,))
10141015
10151016 for i , block in enumerate (self .transformer_blocks ):
10161017 with jax .named_scope (f"Transformer Block { i } " ):
1017- hidden_states = jax .lax .with_sharding_constraint (hidden_states , activation_axis_names )
1018- audio_hidden_states = jax .lax .with_sharding_constraint (audio_hidden_states , activation_axis_names )
1018+ graphdef , state = nnx .split (block )
1019+
1020+ def _apply_weight_sharding (path , x ):
1021+ path_str = "/" .join (getattr (p , "name" , getattr (p , "key" , str (p ))) for p in path )
1022+ if "kernel" in path_str :
1023+ if "ff" in path_str :
1024+ return jax .lax .with_sharding_constraint (x , mlp_rules )
1025+ elif "attn" in path_str :
1026+ return jax .lax .with_sharding_constraint (x , tensor_rules )
1027+ return x
1028+
1029+ state = jax .tree_util .tree_map_with_path (_apply_weight_sharding , state )
1030+ nnx .update (block , state )
10191031
10201032 hidden_states , audio_hidden_states = block (
10211033 hidden_states = hidden_states ,
0 commit comments