1+ """Copyright 2025 Google LLC
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://www.apache.org/licenses/LICENSE-2.0
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ """

from typing import Optional, Tuple

from flax import nnx
import jax.numpy as jnp

from ... import common_types
from ..attention_flax import NNXAttentionOp

Array = common_types.Array
Mesh = common_types.Mesh
DType = common_types.DType


def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
  """Applies interleaved RoPE to input x.

  Logic matches the LTX-2 PyTorch reference: pairs neighboring channels
  and rotates them as [-x2, x1].

  Args:
    x: Input tensor [B, S, H, D].
    freqs: Tuple of (cos, sin), broadcasting to [B, S, 1, D] or [B, S, H, D].
  """
  cos, sin = freqs

  # 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2].
  # This corresponds to "rearrange(..., (d r) -> ... d r, r=2)".
  x_reshaped = x.reshape(*x.shape[:-1], -1, 2)

  # 2. Split into components: x_real = x[..., 0], x_imag = x[..., 1].
  x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1]

  # 3. Rotate: [-x2, x1]. Corresponds to "stack((-t2, t1))".
  x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)

  # 4. Apply frequencies (in float32 for numerical stability).
  out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin

  return out.astype(x.dtype)

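# Example (illustrative; not part of the module): with D = 2, interleaved
# RoPE reduces to a plain 2D rotation, [x1*cos - x2*sin, x2*cos + x1*sin].
# Rotating (1, 0) by pi/2 should give roughly (0, 1):
#
#   x = jnp.array([[[[1.0, 0.0]]]])              # [B=1, S=1, H=1, D=2]
#   cos = jnp.full((1, 1, 1, 2), jnp.cos(jnp.pi / 2))
#   sin = jnp.full((1, 1, 1, 2), jnp.sin(jnp.pi / 2))
#   out = apply_rotary_emb(x, (cos, sin))        # ~[[[[0.0, 1.0]]]]
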
class LTX2RotaryPosEmbed(nnx.Module):
  """RoPE implementation that accepts pre-computed position IDs.

  This allows flexibility for 3D (video) vs. 1D (audio/temporal) usage.
  """

  def __init__(self, dim: int, theta: float = 10000.0):
    self.dim = dim
    self.theta = theta

  def __call__(self, ids: Array) -> Tuple[Array, Array]:
    """Generates RoPE frequencies.

    Args:
      ids: [B, S, Num_Axes]
        - For 3D video: Num_Axes=3 (time, height, width).
        - For 1D audio: Num_Axes=1 (time).

    Returns:
      cos, sin: [B, S, 1, Dim], ready for broadcasting across heads.
    """
    num_axes = ids.shape[-1]
    dim_per_axis = self.dim // num_axes

    # Standard RoPE inverse frequencies.
    freq_indices = jnp.arange(0, dim_per_axis, 2, dtype=jnp.float32)
    inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))

    freqs_list = []
    for i in range(num_axes):
      axis_pos = ids[..., i]
      # Outer product: [B, S] x [D_axis/2] -> [B, S, D_axis/2].
      freqs = jnp.einsum("bs,d->bsd", axis_pos, inv_freq)
      freqs_list.append(freqs)

    # Concatenate axes -> [B, S, D/2].
    emb = jnp.concatenate(freqs_list, axis=-1)

    cos = jnp.cos(emb)
    sin = jnp.sin(emb)

    # Repeat for interleaved RoPE: [c1, c2] -> [c1, c1, c2, c2].
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)

    # Add head dim for broadcasting: [B, S, 1, Inner_Dim].
    return cos[:, :, None, :], sin[:, :, None, :]

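# Example (illustrative): building 3D position ids for a small video latent
# grid of shape (T, H, W) = (2, 4, 4). The variable names here are local to
# the example, not part of the module.
#
#   t, h, w = jnp.meshgrid(
#       jnp.arange(2), jnp.arange(4), jnp.arange(4), indexing="ij"
#   )
#   ids = jnp.stack([t, h, w], axis=-1).reshape(1, -1, 3)  # [B=1, S=32, 3]
#   rope = LTX2RotaryPosEmbed(dim=96)
#   cos, sin = rope(ids.astype(jnp.float32))               # each [1, 32, 1, 96]
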
class LTX2Attention(nnx.Module):

  def __init__(
      self,
      query_dim: int,
      heads: int,
      dim_head: int,
      context_dim: Optional[int] = None,
      dropout: float = 0.0,
      bias: bool = True,  # LTX-2 uses bias=True for the QKV projections.
      out_bias: bool = True,
      rngs: Optional[nnx.Rngs] = None,
      mesh: Optional[Mesh] = None,
      eps: float = 1e-6,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
  ):
    self.heads = heads
    self.dim_head = dim_head
    self.inner_dim = dim_head * heads
    self.dropout_rate = dropout

    # 1. Projections.
    self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

    # Handle self- vs. cross-attention input dims.
    kv_dim = context_dim if context_dim is not None else query_dim
    self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
    self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

    # 2. Normalization (applied to the full inner_dim, NOT per head).
    self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
    self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)

    # 3. Output projection.
    self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)

    if self.dropout_rate > 0:
      self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)
    else:
      self.dropout_layer = None

    self.attention_op = NNXAttentionOp(
        mesh=mesh,
        attention_kernel=attention_kernel,
        scale=dim_head**-0.5,
        heads=heads,
        dim_head=dim_head,
        dtype=dtype,
    )

  def _reshape_rope(self, rope_emb: Tuple[Array, Array]) -> Tuple[Array, Array]:
    """Reshapes [B, S, 1, InnerDim] -> [B, S, Heads, DimHead] for broadcasting."""
    cos, sin = rope_emb
    # If callers pass already-shaped tensors, return them as-is.
    if cos.ndim == 4 and cos.shape[-2] == self.heads and cos.shape[-1] == self.dim_head:
      return cos, sin

    # Reshape [B, S, 1, H*D] -> [B, S, H, D]. We assume the last two
    # dimensions flatten to InnerDim = Heads * DimHead.
    new_shape = cos.shape[:-2] + (self.heads, self.dim_head)
    return cos.reshape(new_shape), sin.reshape(new_shape)
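
  # Example (illustrative): with heads=4 and dim_head=32, a [B, S, 1, 128]
  # (cos, sin) pair from LTX2RotaryPosEmbed is reshaped by _reshape_rope to
  # [B, S, 4, 32] so it broadcasts per attention head.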

  def __call__(
      self,
      hidden_states: Array,
      encoder_hidden_states: Optional[Array] = None,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
      k_rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    # Determine context (self- or cross-attention).
    context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

    # 1. Project.
    query = self.to_q(hidden_states)
    key = self.to_k(context)
    value = self.to_v(context)

    # 2. Norm (over the full inner dimension).
    query = self.norm_q(query)
    key = self.norm_k(key)

    # 3. Reshape to heads: [B, S, H, D].
    query = query.reshape(*query.shape[:-1], self.heads, self.dim_head)
    key = key.reshape(*key.shape[:-1], self.heads, self.dim_head)
    value = value.reshape(*value.shape[:-1], self.heads, self.dim_head)

    # 4. Apply RoPE.
    if rotary_emb is not None:
      # Reshape [1, Inner] -> [H, D].
      q_rope = self._reshape_rope(rotary_emb)
      query = apply_rotary_emb(query, q_rope)

      # Key RoPE logic.
      if k_rotary_emb is not None:
        # Explicit key RoPE (e.g. cross-modal attention).
        k_rope = self._reshape_rope(k_rotary_emb)
        key = apply_rotary_emb(key, k_rope)
      elif encoder_hidden_states is None:
        # Self-attention: re-use q_rope for the keys.
        key = apply_rotary_emb(key, q_rope)

    # 5. Flatten back for the attention op: [B, S, H*D].
    # NNXAttentionOp expects flattened input for the flash kernel.
    query = query.reshape(*query.shape[:-2], self.inner_dim)
    key = key.reshape(*key.shape[:-2], self.inner_dim)
    value = value.reshape(*value.shape[:-2], self.inner_dim)

    # 6. Attention.
    attn_output = self.attention_op.apply_attention(
        query=query, key=key, value=value, attention_mask=attention_mask
    )

    # 7. Output projection.
    hidden_states = self.to_out(attn_output)

    if self.dropout_layer is not None:
      hidden_states = self.dropout_layer(hidden_states)

    return hidden_states
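
# A minimal smoke-test sketch (illustrative, not part of the module). It
# assumes NNXAttentionOp accepts mesh=None together with the "dot_product"
# kernel; on a real TPU mesh you would pass a jax.sharding.Mesh and keep
# the default "flash" kernel.
#
#   rngs = nnx.Rngs(0)
#   attn = LTX2Attention(
#       query_dim=128, heads=4, dim_head=32,
#       rngs=rngs, mesh=None, attention_kernel="dot_product",
#   )
#   rope = LTX2RotaryPosEmbed(dim=128)
#   ids = jnp.arange(16, dtype=jnp.float32)[None, :, None]  # [B=1, S=16, 1]
#   x = jnp.ones((1, 16, 128))                              # [B, S, query_dim]
#   out = attn(x, rotary_emb=rope(ids))                     # [1, 16, 128]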