@@ -54,6 +54,85 @@ def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
5454 return out .astype (x .dtype )
5555
5656
57+ def apply_split_rotary_emb (x : Array , freqs : Tuple [Array , Array ]) -> Array :
58+ """
59+ Applies Split RoPE to input x.
60+ Logic matches Diffusers apply_split_rotary_emb.
61+
62+ Args:
63+ x: Input tensor.
64+ If ndim=3 [B, S, D], it will be reshaped to satisfy cos/sin shapes if needed.
65+ freqs: Tuple of (cos, sin).
66+ Expected to be [B, H, S, D//2] if coming from LTX2RotaryPosEmbed(split).
67+ """
68+ cos , sin = freqs
69+
70+ x_dtype = x .dtype
71+ needed_reshape = False
72+ original_shape = x .shape
73+
74+ # Check if we need to reshape x to match cos layout (B, H, S, D//2)
75+ # x typically [B, S, H*D] or [B, S, D]
76+ # cos typically [B, H, S, D//2]
77+
78+ if x .ndim != 4 and cos .ndim == 4 :
79+ # x is [B, S, Dim]
80+ # cos is [B, H, S, R]
81+ b = x .shape [0 ]
82+ h , s , r = cos .shape [1 ], cos .shape [2 ], cos .shape [3 ]
83+
84+ # Verify dimensions roughly match
85+ # D (dim per head) = R * 2
86+ # Dim = H * D = H * 2 * R
87+
88+ # reshape x to [B, S, H, 2*R] -> transpose to [B, H, S, 2*R]
89+ x = x .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
90+ needed_reshape = True
91+
92+ # Now x should be [..., 2*R] i.e. [B, H, S, 2*R] considering the logic below
93+
94+ last_dim = x .shape [- 1 ]
95+ r = last_dim // 2
96+
97+ # Reshape last dim to (2, r)
98+ # [..., 2*R] -> [..., 2, R]
99+ split_x = x .reshape (* x .shape [:- 1 ], 2 , r )
100+
101+ # Split into first and second half
102+ first_x = split_x [..., 0 , :] # [..., R]
103+ second_x = split_x [..., 1 , :] # [..., R]
104+
105+ # Broadcast cos/sin: [B, H, S, R] -> [B, H, S, 1, R]
106+ cos_u = jnp .expand_dims (cos , axis = - 2 )
107+ sin_u = jnp .expand_dims (sin , axis = - 2 )
108+
109+ # out = split_x * cos_u
110+ # This applies cos to both halves
111+ out = split_x * cos_u
112+
113+ # Modifications
114+ # first_out = x1*cos - x2*sin
115+ # second_out = x2*cos + x1*sin
116+
117+ # Apply updates
118+ # We construct result manually to avoid in-place ops
119+ out_first = out [..., 0 , :] - second_x * sin_u .squeeze (- 2 )
120+ out_second = out [..., 1 , :] + first_x * sin_u .squeeze (- 2 )
121+
122+ # Stack back: [..., 2, R]
123+ out = jnp .stack ([out_first , out_second ], axis = - 2 )
124+
125+ # Flatten back last dim: [..., 2*R]
126+ out = out .reshape (* out .shape [:- 2 ], last_dim )
127+
128+ if needed_reshape :
129+ # [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
130+ out = out .transpose (0 , 2 , 1 , 3 ).reshape (original_shape )
131+
132+ return out .astype (x_dtype )
133+
134+
135+
57136class LTX2RotaryPosEmbed (nnx .Module ):
58137 """
59138 Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.
@@ -131,26 +210,13 @@ def prepare_video_coords(
131210 latent_coords = jnp .expand_dims (latent_coords , 0 ) # [1, num_patches, 3, 2]
132211 latent_coords = jnp .tile (latent_coords , (batch_size , 1 , 1 , 1 )) # [B, num_patches, 3, 2]
133212
134- # Transpose to match desired shape [B, 3, num_patches, 2] if needed,
135- # BUT Diffusers returns [B, 3, num_patches, 2] from flatten(1,3) on [3, N_F, N_H, N_W, 2]??
136- # Diffusers:
137- # latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2]
138- # latent_coords = latent_coords.flatten(1, 3) # [3, num_patches, 2]
139- # latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) # [B, 3, num_patches, 2]
140- # My JAX above:
141- # latent_coords = latent_coords.reshape(-1, 3, 2) was wrong relative to Diffusers shape
142-
143- # Correct JAX implementation matching Diffusers:
144213 latent_coords = jnp .stack ([grid , patch_ends ], axis = - 1 ) # [3, N_F, N_H, N_W, 2]
145214 latent_coords = latent_coords .reshape (3 , - 1 , 2 ) # [3, num_patches, 2]
146215 latent_coords = jnp .expand_dims (latent_coords , 0 ) # [1, 3, num_patches, 2]
147216 latent_coords = jnp .tile (latent_coords , (batch_size , 1 , 1 , 1 )) # [B, 3, num_patches, 2]
148217
149218 # 3. Calculate pixel space coords
150219 scale_tensor = jnp .array (self .scale_factors , dtype = latent_coords .dtype )
151- # Broadcast scale factors: [1, 3, 1, 1] matches [B, 3, num, 2] logic?
152- # Diffusers: broadcast_shape[1] = -1 (frame, height, width dim)
153- # Actually scale_factors is (8, 32, 32) corresponding to (F, H, W) i.e. dim 1 of latent_coords
154220 scale_tensor = scale_tensor .reshape (1 , 3 , 1 , 1 )
155221 pixel_coords = latent_coords * scale_tensor
156222
@@ -260,11 +326,6 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
260326
261327 # Padding if needed
262328 if self .dim % num_rope_elems != 0 :
263- # Diffusers logic: pad with ones/zeros
264- # JAX requires careful padding
265- # But here we computed freqs for `steps = self.dim // num_rope_elems`
266- # So we have `steps * num_rope_elems` elements currently.
267- # If mismatch, we pad.
268329 curr_dim = cos_freqs .shape [- 1 ]
269330 pad_amt = self .dim - curr_dim
270331 if pad_amt > 0 :
@@ -274,42 +335,32 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
274335 sin_freqs = jnp .concatenate ([sin_padding , sin_freqs ], axis = - 1 )
275336
276337 elif self .rope_type == "split" :
277- # [B, N, D//2] -> [B, N, D//2]
278- # Padding first? Diffusers:
279- # expected_freqs = self.dim // 2
280- # current_freqs = freqs.shape[-1]
281- # pad_size = expected_freqs - current_freqs
282- # if pad != 0: pad (before logic?)
283- # Actually Diffusers code:
284- # cos_freq = freqs.cos(), sin_freq = freqs.sin()
285- # if pad: concatenate([pad, cos_freq], axis=-1)
286- # THEN reshape to multi-head?
287-
288- curr_dim = cos_freqs .shape [- 1 ]
338+ # Cos/Sin
339+ cos_freq = jnp .cos (freqs )
340+ sin_freq = jnp .sin (freqs )
341+
342+ curr_dim = cos_freq .shape [- 1 ]
289343 expected_dim = self .dim // 2
290344 pad_size = expected_dim - curr_dim
291345
292346 if pad_size > 0 :
293- cos_padding = jnp .ones ((* cos_freqs .shape [:- 1 ], pad_size ), dtype = cos_freqs .dtype )
294- sin_padding = jnp .zeros ((* sin_freqs .shape [:- 1 ], pad_size ), dtype = sin_freqs .dtype )
295- cos_freqs = jnp .concatenate ([cos_padding , cos_freqs ], axis = - 1 )
296- sin_freqs = jnp .concatenate ([sin_padding , sin_freqs ], axis = - 1 )
297-
298- # Reshape for multi-head?
299- # Diffusers:
300- # cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
301- # swapaxes(1, 2) -> (B, H, T, D//2)
302- # Here: input `coords` was flattened tokens (N).
303- # We assume N = Time?
304- # Wait, `prepare_video_coords` flattens all patches (T*H*W).
305- # So N = T*H*W.
306- # If `rope_type="split"`, does it imply specific Time-Head structure?
307- # LTX-2 `transformer_ltx2.py` in Diffusers passes `rope_type="interleaved"` by default.
308- # Split is mostly for specific attention optimizations.
309- # I will skip the complex reshape logic for now unless requested,
310- # as standard flow is interleaved.
311- # But I should keep the frequency generation logic consistent.
312- pass
347+ cos_padding = jnp .ones ((* cos_freq .shape [:- 1 ], pad_size ), dtype = cos_freq .dtype )
348+ sin_padding = jnp .zeros ((* sin_freq .shape [:- 1 ], pad_size ), dtype = sin_freq .dtype )
349+ cos_freq = jnp .concatenate ([cos_padding , cos_freq ], axis = - 1 )
350+ sin_freq = jnp .concatenate ([sin_padding , sin_freq ], axis = - 1 )
351+
352+ # Reshape freqs to be compatible with multi-head attention
353+ # Diffusers: cos_freq.reshape(b, t, self.num_attention_heads, -1) -> swapaxes(1, 2)
354+ # [B, S, D//2] -> [B, S, H, dim_head//2] -> [B, H, S, dim_head//2]
355+
356+ b = cos_freq .shape [0 ]
357+ s = cos_freq .shape [1 ]
358+
359+ # We need to know H. `LTX2RotaryPosEmbed` has `num_attention_heads`.
360+ h = self .num_attention_heads
361+
362+ cos_freqs = cos_freq .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
363+ sin_freqs = sin_freq .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
313364
314365 return cos_freqs , sin_freqs
315366
@@ -330,8 +381,10 @@ def __init__(
330381 eps : float = 1e-6 ,
331382 dtype : DType = jnp .float32 ,
332383 attention_kernel : str = "flash" ,
384+ rope_type : str = "interleaved" ,
333385 ):
334386 self .heads = heads
387+ self .rope_type = rope_type
335388 self .dim_head = dim_head
336389 self .inner_dim = dim_head * heads
337390 self .dropout_rate = dropout
@@ -387,12 +440,26 @@ def __call__(
387440
388441 # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
389442 # Frequencies are shape [B, S, InnerDim]
443+ # 3. Apply RoPE
390444 if rotary_emb is not None :
391- query = apply_rotary_emb (query , rotary_emb )
392- if k_rotary_emb is not None :
393- key = apply_rotary_emb (key , k_rotary_emb )
394- elif encoder_hidden_states is None :
395- key = apply_rotary_emb (key , rotary_emb )
445+ if hasattr (self , "rope_type" ) and self .rope_type == "split" :
446+ # Split RoPE: passing full freqs [B, H, S, D//2]
447+ # apply_split_rotary_emb handles reshaping query/key
448+
449+ query = apply_split_rotary_emb (query , rotary_emb )
450+
451+ if k_rotary_emb is not None :
452+ key = apply_split_rotary_emb (key , k_rotary_emb )
453+ elif encoder_hidden_states is None :
454+ key = apply_split_rotary_emb (key , rotary_emb )
455+
456+ else :
457+ # Interleaved (Default)
458+ query = apply_rotary_emb (query , rotary_emb )
459+ if k_rotary_emb is not None :
460+ key = apply_rotary_emb (key , k_rotary_emb )
461+ elif encoder_hidden_states is None :
462+ key = apply_rotary_emb (key , rotary_emb )
396463
397464 # 4. Attention
398465 # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
0 commit comments