Skip to content

Commit 998edf4

Browse files
committed
Add jax.named_scope annotations for audio and video self-attention
1 parent dfd0c89 commit 998edf4

1 file changed

Lines changed: 12 additions & 10 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -378,11 +378,12 @@ def __call__(
378 378

379 379
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
380 380

381-
attn_hidden_states = self.attn1(
382-
hidden_states=norm_hidden_states,
383-
encoder_hidden_states=None,
384-
rotary_emb=video_rotary_emb,
385-
)
381+
with jax.named_scope("Video Self-Attention"):
382+
attn_hidden_states = self.attn1(
383+
hidden_states=norm_hidden_states,
384+
encoder_hidden_states=None,
385+
rotary_emb=video_rotary_emb,
386+
)
386 387
hidden_states = hidden_states + attn_hidden_states * gate_msa
387 388

388 389
# Calculate Audio AdaLN values
@@ -402,11 +403,12 @@ def __call__(
402 403

403 404
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
404 405

405-
attn_audio_hidden_states = self.audio_attn1(
406-
hidden_states=norm_audio_hidden_states,
407-
encoder_hidden_states=None,
408-
rotary_emb=audio_rotary_emb,
409-
)
406+
with jax.named_scope("Audio Self-Attention"):
407+
attn_audio_hidden_states = self.audio_attn1(
408+
hidden_states=norm_audio_hidden_states,
409+
encoder_hidden_states=None,
410+
rotary_emb=audio_rotary_emb,
411+
)
410 412
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
411 413

412 414
# 2. Video and Audio Cross-Attention with the text embeddings

0 commit comments

Comments
 (0)