@@ -31,8 +31,8 @@ def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
     Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1].

     Args:
-        x: Input tensor [B, S, H, D]
-        freqs: Tuple of (cos, sin), broadcasting to [B, S, 1, D] or [B, S, H, D]
+        x: Input tensor [..., D]
+        freqs: Tuple of (cos, sin), broadcasting to [..., D]
     """
     cos, sin = freqs

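For orientation, the `[-x2, x1]` neighbor pairing named in the docstring can be sketched in isolation. This is an illustrative reconstruction under assumptions, not the function body from this diff; `rotate_pairs` is a hypothetical helper name:

```python
import jax.numpy as jnp

def rotate_pairs(x):
    # Pair neighboring features along the last axis: (x1, x2) -> (-x2, x1).
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return jnp.stack([-x2, x1], axis=-1).reshape(x.shape)

# With cos/sin broadcast to [..., D], the rotation is shape-agnostic over
# all leading dims, which is what makes the [..., D] contract possible:
# out = x * cos + rotate_pairs(x) * sin
```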
@@ -71,7 +71,7 @@ def __call__(self, ids: Array) -> Tuple[Array, Array]:
             - For Video 3D: Num_Axes=3 (Time, Height, Width)
             - For Audio 1D: Num_Axes=1 (Time)
         Returns:
-            cos, sin: [B, S, 1, Dim] (Ready for broadcasting across heads)
+            cos, sin: [B, S, Dim]
         """
         num_axes = ids.shape[-1]
        dim_per_axis = self.dim // num_axes
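To make the axis split concrete, a worked example with assumed numbers (not from the diff):

```python
# Assumed: dim = 96, video ids of shape [B, S, 3] (time, height, width).
# num_axes       = ids.shape[-1]  # 3
# dim_per_axis   = 96 // 3        # 32 channels per axis
# pairs_per_axis = 32 // 2        # 16 rotation angles per axis, before the
#                                 # 2x repeat in the next hunk
```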
@@ -97,8 +97,7 @@ def __call__(self, ids: Array) -> Tuple[Array, Array]:
         cos = jnp.repeat(cos, 2, axis=-1)
         sin = jnp.repeat(sin, 2, axis=-1)

-        # Add head dim for broadcasting: [B, S, 1, Inner_Dim]
-        return cos[:, :, None, :], sin[:, :, None, :]
+        return cos, sin


 class LTX2Attention(nnx.Module):
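A small sanity check of the repeat step (my own snippet; only `jnp.repeat` semantics are assumed): duplicating each angle in place lines one frequency up with one `(x1, x2)` feature pair, matching the convention in `apply_rotary_emb`:

```python
import jax.numpy as jnp

angles = jnp.array([[0.1, 0.2]])               # per-pair angles, [S, D/2]
cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)  # [S, D]
# -> [[cos(0.1), cos(0.1), cos(0.2), cos(0.2)]]
```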
@@ -151,18 +150,6 @@ def __init__(
             dtype=dtype,
         )

-    def _reshape_rope(self, rope_emb: Tuple[Array, Array]) -> Tuple[Array, Array]:
-        """Reshapes [B, S, 1, InnerDim] -> [B, S, Heads, DimHead] for broadcasting."""
-        cos, sin = rope_emb
-        # If tests pass already shaped tensors, return as is
-        if cos.ndim == 4 and cos.shape[-2] == self.heads and cos.shape[-1] == self.dim_head:
-            return cos, sin
-
-        # Reshape: [B, S, 1, H*D] -> [B, S, H, D]
-        # We assume the last dimension is InnerDim = Heads * DimHead
-        new_shape = cos.shape[:-2] + (self.heads, self.dim_head)
-        return cos.reshape(new_shape), sin.reshape(new_shape)
-
     def __call__(
         self,
         hidden_states: Array,
@@ -184,33 +171,17 @@ def __call__(
         query = self.norm_q(query)
         key = self.norm_k(key)

-        # 3. Reshape to Heads [B, S, H, D]
-        query = query.reshape(*query.shape[:-1], self.heads, self.dim_head)
-        key = key.reshape(*key.shape[:-1], self.heads, self.dim_head)
-        value = value.reshape(*value.shape[:-1], self.heads, self.dim_head)
-
-        # 4. Apply RoPE
+        # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
+        # Frequencies are shape [B, S, InnerDim]
         if rotary_emb is not None:
-            # Reshape [1, Inner] -> [H, D]
-            q_rope = self._reshape_rope(rotary_emb)
-            query = apply_rotary_emb(query, q_rope)
-
-            # Key RoPE Logic
+            query = apply_rotary_emb(query, rotary_emb)
             if k_rotary_emb is not None:
-                # Explicit Key RoPE (e.g. Cross-Modal)
-                k_rope = self._reshape_rope(k_rotary_emb)
-                key = apply_rotary_emb(key, k_rope)
-            elif encoder_hidden_states is None:
-                # Self-Attention: Re-use q_rope
-                key = apply_rotary_emb(key, q_rope)
-
-        # 5. Flatten back for AttentionOp [B, S, H*D]
-        # NNXAttentionOp expects flattened input for flash kernel
-        query = query.reshape(*query.shape[:-2], self.inner_dim)
-        key = key.reshape(*key.shape[:-2], self.inner_dim)
-        value = value.reshape(*value.shape[:-2], self.inner_dim)
-
-        # 6. Attention
+                key = apply_rotary_emb(key, k_rotary_emb)
+            elif encoder_hidden_states is None:
+                key = apply_rotary_emb(key, rotary_emb)
+
+        # 4. Attention
+        # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
         attn_output = self.attention_op.apply_attention(
             query=query, key=key, value=value, attention_mask=attention_mask
         )
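With `_reshape_rope` and the head reshape round-trip removed, query/key/value stay flat as `[B, S, InnerDim]` from projection to attention. A hedged sketch of the resulting call flow; the rope-module name is a placeholder and only the argument names come from this diff:

```python
# cos_sin = rope_module(ids)                     # each [B, S, InnerDim]
# out = attn(hidden_states, rotary_emb=cos_sin)  # q/k rotated while flat
#
# Key handling per the branch above: an explicit k_rotary_emb wins
# (e.g. cross-modal); self-attention (encoder_hidden_states is None)
# reuses rotary_emb for the key; cross-attention otherwise leaves the
# key un-rotated.
```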