@@ -531,7 +531,7 @@ def __init__(
531531
532532 # 4. Rotary Positional Embeddings (RoPE)
533533 self .rope = LTX2AudioVideoRotaryPosEmbed (
534- dim = self .attention_head_dim , # Per head dim
534+ dim = self .inner_dim ,
535535 patch_size = self .patch_size ,
536536 patch_size_t = self .patch_size_t ,
537537 base_num_frames = self .pos_embed_max_pos ,
@@ -547,7 +547,7 @@ def __init__(
547547 dtype = self .dtype ,
548548 )
549549 self .audio_rope = LTX2AudioVideoRotaryPosEmbed (
550- dim = self .audio_attention_head_dim , # Per head dim
550+ dim = self .audio_inner_dim , # Per head dim
551551 patch_size = self .audio_patch_size ,
552552 patch_size_t = self .audio_patch_size_t ,
553553 base_num_frames = self .audio_pos_embed_max_pos ,
@@ -565,7 +565,7 @@ def __init__(
565565
566566 cross_attn_pos_embed_max_pos = max (self .pos_embed_max_pos , self .audio_pos_embed_max_pos )
567567 self .cross_attn_rope = LTX2AudioVideoRotaryPosEmbed (
568- dim = self .attention_head_dim , # Per head dim
568+ dim = self .audio_cross_attention_dim ,
569569 patch_size = self .patch_size ,
570570 patch_size_t = self .patch_size_t ,
571571 base_num_frames = cross_attn_pos_embed_max_pos ,
@@ -580,7 +580,7 @@ def __init__(
580580 dtype = self .dtype ,
581581 )
582582 self .cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed (
583- dim = self .audio_attention_head_dim , # Per head dim
583+ dim = self .audio_cross_attention_dim ,
584584 patch_size = self .audio_patch_size ,
585585 patch_size_t = self .audio_patch_size_t ,
586586 base_num_frames = cross_attn_pos_embed_max_pos ,
0 commit comments