@@ -971,7 +971,7 @@ def __call__(
971971 # 5. Run transformer blocks
972972 def scan_fn (carry , block ):
973973 hidden_states , audio_hidden_states , rngs_carry = carry
974- with jax .named_scope ("Transformer Block i " ):
974+ with jax .named_scope ("Transformer Layer " ):
975975 hidden_states_out , audio_hidden_states_out = block (
976976 hidden_states = hidden_states ,
977977 audio_hidden_states = audio_hidden_states ,
@@ -1010,25 +1010,8 @@ def scan_fn(carry, block):
10101010 transform_metadata = {nnx .PARTITION_NAME : "layers" },
10111011 )(carry , self .transformer_blocks )
10121012 else :
1013- mlp_rules = nn .logical_to_mesh_axes (("mlp" , "tensor" ))
1014- tensor_rules = nn .logical_to_mesh_axes (("tensor" ,))
1015-
10161013 for i , block in enumerate (self .transformer_blocks ):
10171014 with jax .named_scope (f"Transformer Block { i } " ):
1018- graphdef , state = nnx .split (block )
1019-
1020- def _apply_weight_sharding (path , x ):
1021- path_str = "/" .join (getattr (p , "name" , getattr (p , "key" , str (p ))) for p in path )
1022- if "kernel" in path_str :
1023- if "ff" in path_str :
1024- return jax .lax .with_sharding_constraint (x , mlp_rules )
1025- elif "attn" in path_str :
1026- return jax .lax .with_sharding_constraint (x , tensor_rules )
1027- return x
1028-
1029- state = jax .tree_util .tree_map_with_path (_apply_weight_sharding , state )
1030- nnx .update (block , state )
1031-
10321015 hidden_states , audio_hidden_states = block (
10331016 hidden_states = hidden_states ,
10341017 audio_hidden_states = audio_hidden_states ,
0 commit comments