@@ -1010,25 +1010,26 @@ def scan_fn(carry, block):
10101010 transform_metadata = {nnx .PARTITION_NAME : "layers" },
10111011 )(carry , self .transformer_blocks )
10121012 else :
1013- for block in self .transformer_blocks :
1014- hidden_states , audio_hidden_states = block (
1015- hidden_states = hidden_states ,
1016- audio_hidden_states = audio_hidden_states ,
1017- encoder_hidden_states = encoder_hidden_states ,
1018- audio_encoder_hidden_states = audio_encoder_hidden_states ,
1019- temb = temb ,
1020- temb_audio = temb_audio ,
1021- temb_ca_scale_shift = video_cross_attn_scale_shift ,
1022- temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1023- temb_ca_gate = video_cross_attn_a2v_gate ,
1024- temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1025- video_rotary_emb = video_rotary_emb ,
1026- audio_rotary_emb = audio_rotary_emb ,
1027- ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1028- ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1029- encoder_attention_mask = encoder_attention_mask ,
1030- audio_encoder_attention_mask = audio_encoder_attention_mask ,
1031- )
1013+ for i , block in enumerate (self .transformer_blocks ):
1014+ with jax .named_scope (f"Transformer Block { i } " ):
1015+ hidden_states , audio_hidden_states = block (
1016+ hidden_states = hidden_states ,
1017+ audio_hidden_states = audio_hidden_states ,
1018+ encoder_hidden_states = encoder_hidden_states ,
1019+ audio_encoder_hidden_states = audio_encoder_hidden_states ,
1020+ temb = temb ,
1021+ temb_audio = temb_audio ,
1022+ temb_ca_scale_shift = video_cross_attn_scale_shift ,
1023+ temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
1024+ temb_ca_gate = video_cross_attn_a2v_gate ,
1025+ temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1026+ video_rotary_emb = video_rotary_emb ,
1027+ audio_rotary_emb = audio_rotary_emb ,
1028+ ca_video_rotary_emb = video_cross_attn_rotary_emb ,
1029+ ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1030+ encoder_attention_mask = encoder_attention_mask ,
1031+ audio_encoder_attention_mask = audio_encoder_attention_mask ,
1032+ )
10321033
10331034 # 6. Output layers
10341035 scale_shift_values = jnp .expand_dims (self .scale_shift_table , axis = (0 , 1 )) + jnp .expand_dims (embedded_timestep , axis = 2 )
0 commit comments