@@ -59,49 +59,48 @@ def apply_split_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
5959 """
6060 Applies Split RoPE to input x.
6161 Logic matches Diffusers apply_split_rotary_emb.
62-
62+
6363 Args:
64- x: Input tensor.
64+ x: Input tensor.
6565 If ndim=3 [B, S, D], it will be reshaped to satisfy cos/sin shapes if needed.
66- freqs: Tuple of (cos, sin).
66+ freqs: Tuple of (cos, sin).
6767 Expected to be [B, H, S, D//2] if coming from LTX2RotaryPosEmbed(split).
6868 """
6969 cos , sin = freqs
70-
70+
7171 x_dtype = x .dtype
7272 needed_reshape = False
7373 original_shape = x .shape
74-
74+
7575 if x .ndim != 4 and cos .ndim == 4 :
76- b = x .shape [0 ]
77- h , s , r = cos .shape [1 ], cos .shape [2 ], cos .shape [3 ]
78- x = x .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
79- needed_reshape = True
80-
76+ b = x .shape [0 ]
77+ h , s , r = cos .shape [1 ], cos .shape [2 ], cos .shape [3 ]
78+ x = x .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
79+ needed_reshape = True
80+
8181 last_dim = x .shape [- 1 ]
8282 r = last_dim // 2
83-
83+
8484 split_x = x .reshape (* x .shape [:- 1 ], 2 , r )
85-
85+
8686 first_x = split_x [..., 0 , :]
8787 second_x = split_x [..., 1 , :]
88-
88+
8989 cos_u = jnp .expand_dims (cos , axis = - 2 )
9090 sin_u = jnp .expand_dims (sin , axis = - 2 )
91-
91+
9292 out = split_x * cos_u
93-
93+
9494 out_first = out [..., 0 , :] - second_x * sin_u .squeeze (- 2 )
9595 out_second = out [..., 1 , :] + first_x * sin_u .squeeze (- 2 )
96-
96+
9797 out = jnp .stack ([out_first , out_second ], axis = - 2 )
9898 out = out .reshape (* out .shape [:- 2 ], last_dim )
99-
99+
100100 if needed_reshape :
101- out = out .transpose (0 , 2 , 1 , 3 ).reshape (original_shape )
102-
103- return out .astype (x_dtype )
101+ out = out .transpose (0 , 2 , 1 , 3 ).reshape (original_shape )
104102
103+ return out .astype (x_dtype )
105104
106105
107106class LTX2RotaryPosEmbed (nnx .Module ):
@@ -308,11 +307,10 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
308307 cos_freq = jnp .concatenate ([cos_padding , cos_freq ], axis = - 1 )
309308 sin_freq = jnp .concatenate ([sin_padding , sin_freq ], axis = - 1 )
310309
311-
312310 b = cos_freq .shape [0 ]
313311 s = cos_freq .shape [1 ]
314312 h = self .num_attention_heads
315-
313+
316314 cos_freqs = cos_freq .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
317315 sin_freqs = sin_freq .reshape (b , s , h , - 1 ).transpose (0 , 2 , 1 , 3 )
318316
@@ -352,8 +350,12 @@ def __init__(
352350 self .to_v = nnx .Linear (kv_dim , self .inner_dim , use_bias = bias , rngs = rngs , dtype = dtype )
353351
354352 # 2. Normalization (Applied to full inner_dim, NOT per-head)
355- self .norm_q = nnx .RMSNorm (self .inner_dim , epsilon = eps , dtype = jnp .float32 , param_dtype = jnp .float32 , use_scale = True , rngs = rngs )
356- self .norm_k = nnx .RMSNorm (self .inner_dim , epsilon = eps , dtype = jnp .float32 , param_dtype = jnp .float32 , use_scale = True , rngs = rngs )
353+ self .norm_q = nnx .RMSNorm (
354+ self .inner_dim , epsilon = eps , dtype = jnp .float32 , param_dtype = jnp .float32 , use_scale = True , rngs = rngs
355+ )
356+ self .norm_k = nnx .RMSNorm (
357+ self .inner_dim , epsilon = eps , dtype = jnp .float32 , param_dtype = jnp .float32 , use_scale = True , rngs = rngs
358+ )
357359
358360 # 3. Output
359361 self .to_out = nnx .Linear (self .inner_dim , query_dim , use_bias = out_bias , rngs = rngs , dtype = dtype )
@@ -397,23 +399,23 @@ def __call__(
397399 # 3. Apply RoPE
398400 if rotary_emb is not None :
399401 if hasattr (self , "rope_type" ) and self .rope_type == "split" :
400- # Split RoPE: passing full freqs [B, H, S, D//2]
401- # apply_split_rotary_emb handles reshaping query/key
402-
403- query = apply_split_rotary_emb (query , rotary_emb )
404-
405- if k_rotary_emb is not None :
406- key = apply_split_rotary_emb (key , k_rotary_emb )
407- elif encoder_hidden_states is None :
408- key = apply_split_rotary_emb (key , rotary_emb )
409-
402+ # Split RoPE: passing full freqs [B, H, S, D//2]
403+ # apply_split_rotary_emb handles reshaping query/key
404+
405+ query = apply_split_rotary_emb (query , rotary_emb )
406+
407+ if k_rotary_emb is not None :
408+ key = apply_split_rotary_emb (key , k_rotary_emb )
409+ elif encoder_hidden_states is None :
410+ key = apply_split_rotary_emb (key , rotary_emb )
411+
410412 else :
411- # Interleaved (Default)
412- query = apply_rotary_emb (query , rotary_emb )
413- if k_rotary_emb is not None :
414- key = apply_rotary_emb (key , k_rotary_emb )
415- elif encoder_hidden_states is None :
416- key = apply_rotary_emb (key , rotary_emb )
413+ # Interleaved (Default)
414+ query = apply_rotary_emb (query , rotary_emb )
415+ if k_rotary_emb is not None :
416+ key = apply_rotary_emb (key , k_rotary_emb )
417+ elif encoder_hidden_states is None :
418+ key = apply_rotary_emb (key , rotary_emb )
417419
418420 # 4. Attention
419421 # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
0 commit comments