1+ """Copyright 2025 Google LLC
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://www.apache.org/licenses/LICENSE-2.0
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ """

from typing import Any, Dict, Optional, Tuple, Union
from flax import nnx
import jax
import jax.numpy as jnp
from ... import common_types
from ..attention_flax import NNXAttentionOp

Array = common_types.Array
Mesh = common_types.Mesh
DType = common_types.DType


def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
  """Applies interleaved RoPE to the input x.

  The logic matches the LTX-2 PyTorch reference: neighboring pairs are
  rotated as [-x2, x1].

  Args:
    x: Input tensor [..., D]
    freqs: Tuple of (cos, sin), each broadcasting to [..., D]
  """
  cos, sin = freqs

  # 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2]
  # This corresponds to "rearrange(..., (d r) -> ... d r, r=2)"
  x_reshaped = x.reshape(*x.shape[:-1], -1, 2)

  # 2. Split into components
  # x_real = x[..., 0], x_imag = x[..., 1]
  x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1]

  # 3. Rotate: [-x2, x1]
  # Corresponds to "stack((-t2, t1))"
  x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)

  # 4. Apply frequencies (float32 for numerical stability)
  out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin

  return out.astype(x.dtype)


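# Illustrative sketch, not part of the original module: it demonstrates the
# interleaved pairing convention of apply_rotary_emb on a tiny tensor. The
# shapes and the identity/quarter-turn frequencies below are assumptions
# chosen purely for demonstration.
def _example_apply_rotary_emb() -> Tuple[Array, Array]:
  """Tiny sanity check of the interleaved rotation (illustrative only)."""
  x = jnp.arange(8, dtype=jnp.float32).reshape(1, 1, 8)  # [B=1, S=1, D=8]
  ones = jnp.ones_like(x)
  zeros = jnp.zeros_like(x)

  # cos=1, sin=0 is the identity rotation: the output equals the input.
  out_identity = apply_rotary_emb(x, (ones, zeros))

  # cos=0, sin=1 keeps only the rotated pairs: [-x2, x1, -x4, x3, ...].
  out_rotated = apply_rotary_emb(x, (zeros, ones))
  return out_identity, out_rotated

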
class LTX2RotaryPosEmbed(nnx.Module):
  """RoPE implementation that accepts pre-computed position IDs.

  This allows flexibility for 3D (video) vs. 1D (audio/temporal) usage.
  """

  def __init__(self, dim: int, theta: float = 10000.0):
    self.dim = dim
    self.theta = theta

  def __call__(self, ids: Array) -> Tuple[Array, Array]:
    """Generates RoPE frequencies.

    Args:
      ids: [B, S, Num_Axes]
        - For video (3D): Num_Axes=3 (Time, Height, Width)
        - For audio (1D): Num_Axes=1 (Time)

    Returns:
      cos, sin: [B, S, Dim]
    """
    num_axes = ids.shape[-1]
    dim_per_axis = self.dim // num_axes

    # Standard RoPE frequencies
    freq_indices = jnp.arange(0, dim_per_axis, 2, dtype=jnp.float32)
    inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))

    freqs_list = []
    for i in range(num_axes):
      axis_pos = ids[..., i]
      # Outer product: [B, S] x [D_axis/2] -> [B, S, D_axis/2]
      freqs = jnp.einsum('bs,d->bsd', axis_pos, inv_freq)
      freqs_list.append(freqs)

    # Concatenate axes -> [B, S, D/2]
    emb = jnp.concatenate(freqs_list, axis=-1)

    cos = jnp.cos(emb)
    sin = jnp.sin(emb)

    # Repeat for interleaved RoPE: [c1, c2] -> [c1, c1, c2, c2]
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)

    return cos, sin


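# Illustrative sketch, not part of the original module: it shows how the
# embedder can be used for video-style 3D position IDs and audio-style 1D
# position IDs, and how the resulting frequencies feed apply_rotary_emb.
# All sizes (batch=1, seq=16, dim=48) are assumptions for demonstration only;
# dim=48 is chosen so it splits evenly across both one and three axes.
def _example_rotary_pos_embed() -> Tuple[Array, Array]:
  """Builds 3D and 1D RoPE frequencies and applies them (illustrative only)."""
  dim = 48
  rope = LTX2RotaryPosEmbed(dim=dim)

  # Video-style IDs: one (time, height, width) triple per token -> [B, S, 3].
  video_ids = jnp.zeros((1, 16, 3), dtype=jnp.float32)
  cos_3d, sin_3d = rope(video_ids)  # each [1, 16, 48]

  # Audio-style IDs: a single temporal axis per token -> [B, S, 1].
  audio_ids = jnp.arange(16, dtype=jnp.float32).reshape(1, 16, 1)
  cos_1d, sin_1d = rope(audio_ids)  # each [1, 16, 48]

  # The frequencies are applied to query/key tensors with a matching last dim.
  q = jnp.ones((1, 16, dim), dtype=jnp.float32)
  q_video = apply_rotary_emb(q, (cos_3d, sin_3d))
  q_audio = apply_rotary_emb(q, (cos_1d, sin_1d))
  return q_video, q_audio

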
class LTX2Attention(nnx.Module):
  def __init__(
      self,
      query_dim: int,
      heads: int,
      dim_head: int,
      context_dim: Optional[int] = None,
      dropout: float = 0.0,
      bias: bool = True,  # LTX-2 uses bias=True for projections
      out_bias: bool = True,
      rngs: nnx.Rngs = None,
      mesh: Mesh = None,
      eps: float = 1e-6,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
  ):
    self.heads = heads
    self.dim_head = dim_head
    self.inner_dim = dim_head * heads
    self.dropout_rate = dropout

    # 1. Projections
    self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

    # Handle self- vs. cross-attention input dims
    kv_dim = context_dim if context_dim is not None else query_dim
    self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
    self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

    # 2. Normalization (applied to the full inner_dim, NOT per head)
    self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
    self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)

    # 3. Output
    self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)

    if self.dropout_rate > 0:
      self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)
    else:
      self.dropout_layer = None

    self.attention_op = NNXAttentionOp(
        mesh=mesh,
        attention_kernel=attention_kernel,
        scale=dim_head**-0.5,
        heads=heads,
        dim_head=dim_head,
        dtype=dtype,
    )

  def __call__(
      self,
      hidden_states: Array,
      encoder_hidden_states: Optional[Array] = None,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
      k_rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:

    # Determine context (self- or cross-attention)
    context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

    # 1. Project
    query = self.to_q(hidden_states)
    key = self.to_k(context)
    value = self.to_v(context)

    # 2. Norm (full inner dimension)
    query = self.norm_q(query)
    key = self.norm_k(key)

    # 3. Apply RoPE to tensors of shape [B, S, InnerDim].
    # Frequencies are shape [B, S, InnerDim]. For cross-attention, the key is
    # only rotated when a dedicated k_rotary_emb is provided.
    if rotary_emb is not None:
      query = apply_rotary_emb(query, rotary_emb)
      if k_rotary_emb is not None:
        key = apply_rotary_emb(key, k_rotary_emb)
      elif encoder_hidden_states is None:
        key = apply_rotary_emb(key, rotary_emb)

    # 4. Attention
    # NNXAttentionOp expects flattened input [B, S, InnerDim] for the flash kernel.
    attn_output = self.attention_op.apply_attention(
        query=query, key=key, value=value, attention_mask=attention_mask
    )

    # 5. Output projection
    hidden_states = self.to_out(attn_output)

    if self.dropout_layer is not None:
      hidden_states = self.dropout_layer(hidden_states)

    return hidden_states
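

# Illustrative usage sketch, not part of the original module: it wires the
# pieces above together for a single self-attention forward pass. The concrete
# sizes, mesh=None, and the "dot_product" kernel name are assumptions made for
# demonstration; the kernels actually accepted are defined by NNXAttentionOp
# in attention_flax.
def _example_ltx2_attention() -> Array:
  """Runs a tiny self-attention forward pass with RoPE (illustrative only)."""
  heads, dim_head = 2, 24
  inner_dim = heads * dim_head  # 48, matching the RoPE dim used above

  attn = LTX2Attention(
      query_dim=inner_dim,
      heads=heads,
      dim_head=dim_head,
      rngs=nnx.Rngs(0),
      mesh=None,  # assumption: no sharding for this sketch
      attention_kernel="dot_product",  # assumption: kernel name supported by NNXAttentionOp
  )

  rope = LTX2RotaryPosEmbed(dim=inner_dim)
  ids = jnp.arange(16, dtype=jnp.float32).reshape(1, 16, 1)  # 1D temporal IDs
  freqs = rope(ids)

  hidden_states = jnp.ones((1, 16, inner_dim), dtype=jnp.float32)
  return attn(hidden_states, rotary_emb=freqs)  # [1, 16, 48]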