@@ -378,11 +378,12 @@ def __call__(
378378
379379 norm_hidden_states = norm_hidden_states * (1 + scale_msa ) + shift_msa
380380
381- attn_hidden_states = self .attn1 (
382- hidden_states = norm_hidden_states ,
383- encoder_hidden_states = None ,
384- rotary_emb = video_rotary_emb ,
385- )
381+ with jax .named_scope ("Video Self-Attention" ):
382+ attn_hidden_states = self .attn1 (
383+ hidden_states = norm_hidden_states ,
384+ encoder_hidden_states = None ,
385+ rotary_emb = video_rotary_emb ,
386+ )
386387 hidden_states = hidden_states + attn_hidden_states * gate_msa
387388
388389 # Calculate Audio AdaLN values
@@ -402,11 +403,12 @@ def __call__(
402403
403404 norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa ) + audio_shift_msa
404405
405- attn_audio_hidden_states = self .audio_attn1 (
406- hidden_states = norm_audio_hidden_states ,
407- encoder_hidden_states = None ,
408- rotary_emb = audio_rotary_emb ,
409- )
406+ with jax .named_scope ("Audio Self-Attention" ):
407+ attn_audio_hidden_states = self .audio_attn1 (
408+ hidden_states = norm_audio_hidden_states ,
409+ encoder_hidden_states = None ,
410+ rotary_emb = audio_rotary_emb ,
411+ )
410412 audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
411413
412414 # 2. Video and Audio Cross-Attention with the text embeddings
0 commit comments