Skip to content

Commit bf161f9

Browse files
committed
change for enumerating transformer blocks with scan layers set to False
1 parent f14d691 commit bf161f9

1 file changed

Lines changed: 20 additions & 19 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,25 +1010,26 @@ def scan_fn(carry, block):
10101010
transform_metadata={nnx.PARTITION_NAME: "layers"},
10111011
)(carry, self.transformer_blocks)
10121012
else:
1013-
for block in self.transformer_blocks:
1014-
hidden_states, audio_hidden_states = block(
1015-
hidden_states=hidden_states,
1016-
audio_hidden_states=audio_hidden_states,
1017-
encoder_hidden_states=encoder_hidden_states,
1018-
audio_encoder_hidden_states=audio_encoder_hidden_states,
1019-
temb=temb,
1020-
temb_audio=temb_audio,
1021-
temb_ca_scale_shift=video_cross_attn_scale_shift,
1022-
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1023-
temb_ca_gate=video_cross_attn_a2v_gate,
1024-
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1025-
video_rotary_emb=video_rotary_emb,
1026-
audio_rotary_emb=audio_rotary_emb,
1027-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1028-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1029-
encoder_attention_mask=encoder_attention_mask,
1030-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1031-
)
1013+
for i, block in enumerate(self.transformer_blocks):
1014+
with jax.named_scope(f"Transformer Block {i}"):
1015+
hidden_states, audio_hidden_states = block(
1016+
hidden_states=hidden_states,
1017+
audio_hidden_states=audio_hidden_states,
1018+
encoder_hidden_states=encoder_hidden_states,
1019+
audio_encoder_hidden_states=audio_encoder_hidden_states,
1020+
temb=temb,
1021+
temb_audio=temb_audio,
1022+
temb_ca_scale_shift=video_cross_attn_scale_shift,
1023+
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1024+
temb_ca_gate=video_cross_attn_a2v_gate,
1025+
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1026+
video_rotary_emb=video_rotary_emb,
1027+
audio_rotary_emb=audio_rotary_emb,
1028+
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1029+
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1030+
encoder_attention_mask=encoder_attention_mask,
1031+
audio_encoder_attention_mask=audio_encoder_attention_mask,
1032+
)
10321033

10331034
# 6. Output layers
10341035
scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)

0 commit comments

Comments
 (0)