

def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
-    """Apply rotary embeddings to input x."""
+    """
+    Applies interleaved RoPE to input x.
+    Logic matches the LTX-2 PyTorch reference: pairs neighbors as [-x2, x1].
+
+    Args:
+        x: Input tensor [B, S, H, D]
+        freqs: Tuple of (cos, sin), broadcastable to [B, S, 1, D] or [B, S, H, D]
+    """
    cos, sin = freqs
-    # x shape: [B, S, H, D]
-    # cos/sin shape: [B, S, 1, D]

-    # Standard interleaved rotation: [-x2, x1]
+    # 1. Reshape to pair neighbors: [..., D] -> [..., D//2, 2]
+    # This corresponds to "rearrange(..., (d r) -> ... d r, r=2)"
    x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
+
+    # 2. Split into components
+    # x_real = x[..., 0], x_imag = x[..., 1]
    x_real, x_imag = x_reshaped[..., 0], x_reshaped[..., 1]

+    # 3. Rotate: [-x2, x1]
+    # Corresponds to "stack((-t2, t1))"
    x_rotated = jnp.stack([-x_imag, x_real], axis=-1).reshape(*x.shape)

+    # 4. Apply frequencies (float32 for numerical stability)
    out = x.astype(jnp.float32) * cos + x_rotated.astype(jnp.float32) * sin
+
    return out.astype(x.dtype)
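
For intuition: each neighboring pair (x1, x2) is treated as the complex number x1 + i*x2 and multiplied by cos(t) + i*sin(t). A minimal sketch of that equivalence, assuming only jax.numpy; the values are illustrative and not part of this commit:

    import jax.numpy as jnp

    x1, x2, t = 1.0, 2.0, 0.3
    pair = jnp.array([x1, x2])
    # The [-x2, x1] trick from apply_rotary_emb, on a single pair
    rotated = pair * jnp.cos(t) + jnp.array([-x2, x1]) * jnp.sin(t)
    # The same rotation, written as complex multiplication
    z = (x1 + 1j * x2) * jnp.exp(1j * t)
    assert jnp.allclose(rotated, jnp.array([z.real, z.imag]))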


class LTX2RotaryPosEmbed(nnx.Module):
+    """
+    RoPE implementation that accepts pre-computed position IDs.
+    Allows flexibility for 3D (video) vs 1D (audio/temporal) usage.
+    """
    def __init__(self, dim: int, theta: float = 10000.0):
        self.dim = dim
        self.theta = theta
@@ -51,31 +68,36 @@ def __call__(self, ids: Array) -> Tuple[Array, Array]:
        Generates RoPE frequencies.
        Args:
            ids: [B, S, Num_Axes]
-                - For Video 3D: Num_Axes=3 (T, H, W)
-                - For Audio 1D: Num_Axes=1 (T)
-                - For Temporal-Only: Pass ids[:, :, 0:1] (Slice to keep only Time)
+                - For Video 3D: Num_Axes=3 (Time, Height, Width)
+                - For Audio 1D: Num_Axes=1 (Time)
+        Returns:
+            cos, sin: [B, S, 1, Dim] (ready for broadcasting across heads)
        """
        num_axes = ids.shape[-1]
        dim_per_axis = self.dim // num_axes

+        # Standard RoPE inverse frequencies
        freq_indices = jnp.arange(0, dim_per_axis, 2, dtype=jnp.float32)
        inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))

        freqs_list = []
        for i in range(num_axes):
            axis_pos = ids[..., i]
+            # Outer product: [B, S] x [D_axis/2] -> [B, S, D_axis/2]
            freqs = jnp.einsum('bs,d->bsd', axis_pos, inv_freq)
            freqs_list.append(freqs)

+        # Concatenate axes -> [B, S, D/2]
        emb = jnp.concatenate(freqs_list, axis=-1)

        cos = jnp.cos(emb)
        sin = jnp.sin(emb)

+        # Repeat for interleaved RoPE: [c1, c2] -> [c1, c1, c2, c2]
        cos = jnp.repeat(cos, 2, axis=-1)
        sin = jnp.repeat(sin, 2, axis=-1)

-        # Add head dim: [B, S, 1, D]
+        # Add head dim for broadcasting: [B, S, 1, Inner_Dim]
        return cos[:, :, None, :], sin[:, :, None, :]
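
To make the axis handling concrete, here is a hypothetical call for the 1D (audio/temporal) case; the batch and sequence sizes are made up for illustration:

    rope = LTX2RotaryPosEmbed(dim=64)

    # Temporal position IDs: [B=2, S=8, Num_Axes=1]
    ids = jnp.tile(jnp.arange(8, dtype=jnp.float32)[None, :, None], (2, 1, 1))

    cos, sin = rope(ids)
    assert cos.shape == (2, 8, 1, 64)  # [B, S, 1, Dim], broadcasts across heads

For the video 3D case, ids would carry three axes (Time, Height, Width) and the dim is split evenly across them via dim_per_axis = dim // num_axes.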
@@ -87,7 +109,7 @@ def __init__(
        dim_head: int,
        context_dim: Optional[int] = None,
        dropout: float = 0.0,
-        bias: bool = True,
+        bias: bool = True,  # LTX-2 uses bias=True for projections
        out_bias: bool = True,
        rngs: nnx.Rngs = None,
        mesh: Mesh = None,
@@ -100,16 +122,19 @@ def __init__(
        self.inner_dim = dim_head * heads
        self.dropout_rate = dropout

+        # 1. Projections
        self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

+        # Handle self- vs cross-attention input dims
        kv_dim = context_dim if context_dim is not None else query_dim
        self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
        self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)

-        # Norm over full inner_dim (Fix #2)
+        # 2. Normalization (applied to the full inner_dim, NOT per-head)
        self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)
        self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=dtype, use_scale=True, rngs=rngs)

+        # 3. Output projection
        self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)

        if self.dropout_rate > 0:
@@ -126,6 +151,18 @@ def __init__(
            dtype=dtype,
        )

+    def _reshape_rope(self, rope_emb: Tuple[Array, Array]) -> Tuple[Array, Array]:
+        """Reshapes [B, S, 1, InnerDim] -> [B, S, Heads, DimHead] for broadcasting."""
+        cos, sin = rope_emb
+        # If tests pass already-shaped tensors, return them as-is
+        if cos.ndim == 4 and cos.shape[-2] == self.heads and cos.shape[-1] == self.dim_head:
+            return cos, sin
+
+        # Reshape: [B, S, 1, H*D] -> [B, S, H, D]
+        # We assume the last dimension is InnerDim = Heads * DimHead
+        new_shape = cos.shape[:-2] + (self.heads, self.dim_head)
+        return cos.reshape(new_shape), sin.reshape(new_shape)
+
    def __call__(
        self,
        hidden_states: Array,
@@ -135,33 +172,40 @@ def __call__(
        k_rotary_emb: Optional[Tuple[Array, Array]] = None,
    ) -> Array:

+        # Determine context (self- or cross-attention)
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        # 1. Project
        query = self.to_q(hidden_states)
        key = self.to_k(context)
        value = self.to_v(context)

-        # 2. Norm
+        # 2. Norm (over the full inner dimension)
        query = self.norm_q(query)
        key = self.norm_k(key)

-        # 3. Reshape for RoPE [B, S, H, D]
+        # 3. Reshape to heads: [B, S, H, D]
        query = query.reshape(*query.shape[:-1], self.heads, self.dim_head)
        key = key.reshape(*key.shape[:-1], self.heads, self.dim_head)
        value = value.reshape(*value.shape[:-1], self.heads, self.dim_head)

        # 4. Apply RoPE
        if rotary_emb is not None:
-            query = apply_rotary_emb(query, rotary_emb)
+            # Reshape [B, S, 1, InnerDim] -> [B, S, Heads, DimHead]
+            q_rope = self._reshape_rope(rotary_emb)
+            query = apply_rotary_emb(query, q_rope)

+        # Key RoPE logic
        if k_rotary_emb is not None:
-            key = apply_rotary_emb(key, k_rotary_emb)
-        elif encoder_hidden_states is None:  # Self-Attention
-            key = apply_rotary_emb(key, rotary_emb)
-
-        # 5. Flatten back for AttentionOp (Fix #1)
-        # [B, S, H, D] -> [B, S, H*D]
+            # Explicit key RoPE (e.g. cross-modal)
+            k_rope = self._reshape_rope(k_rotary_emb)
+            key = apply_rotary_emb(key, k_rope)
+        elif encoder_hidden_states is None and rotary_emb is not None:
+            # Self-attention: re-use q_rope (the guard avoids a NameError when no RoPE is passed)
+            key = apply_rotary_emb(key, q_rope)
+
+        # 5. Flatten back for AttentionOp: [B, S, H*D]
+        # NNXAttentionOp expects flattened input for the flash kernel
        query = query.reshape(*query.shape[:-2], self.inner_dim)
        key = key.reshape(*key.shape[:-2], self.inner_dim)
        value = value.reshape(*value.shape[:-2], self.inner_dim)
@@ -171,10 +215,10 @@ def __call__(
            query=query, key=key, value=value, attention_mask=attention_mask
        )

-        # attn_output is already [B, S, H*D], no reshape needed before output proj
+        # 7. Output projection
        hidden_states = self.to_out(attn_output)

        if self.dropout_layer is not None:
            hidden_states = self.dropout_layer(hidden_states)

-        return hidden_states
+        return hidden_states
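
Putting the pieces together for the self-attention path — a hedged sketch, since the attention class name and several constructor arguments (e.g. eps and the NNXAttentionOp wiring) sit outside the hunks shown above; `LTX2Attention` and the sizes below are assumptions for illustration only:

    from flax import nnx
    import jax.numpy as jnp

    B, S, H, D = 2, 8, 4, 16

    rope = LTX2RotaryPosEmbed(dim=H * D)
    ids = jnp.tile(jnp.arange(S, dtype=jnp.float32)[None, :, None], (B, 1, 1))
    rotary_emb = rope(ids)  # cos/sin: [B, S, 1, H*D]

    attn = LTX2Attention(  # hypothetical class name for the module edited above
        query_dim=H * D, heads=H, dim_head=D, rngs=nnx.Rngs(0)
    )
    x = jnp.ones((B, S, H * D))
    y = attn(x, rotary_emb=rotary_emb)  # self-attention: keys re-use the query RoPE
    assert y.shape == (B, S, H * D)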