@@ -60,7 +60,7 @@ def forward(self, ids):
         cos = torch.repeat_interleave(cos, 2, dim=-1)
         sin = torch.repeat_interleave(sin, 2, dim=-1)

-        # CORRECT: Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
+        # Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
         return cos, sin

@@ -111,35 +111,37 @@ def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
         k = self.to_k(ctx)
         v = self.to_v(ctx)

-        q = self.q_norm(q)
-        k = self.k_norm(k)
+        # Keep raw projections for test_layer_wise_stats
+        q_raw, k_raw = q, k
+
+        q_normed = self.q_norm(q)
+        k_normed = self.k_norm(k)

-        # CORRECT: Apply RoPE globally BEFORE splitting heads
         if q_rope is not None:
             q_cos, q_sin = q_rope
-            q = apply_rotary_emb_pt(q, q_cos, q_sin)
+            q_normed = apply_rotary_emb_pt(q_normed, q_cos, q_sin)

         if k_rope is not None:
             k_cos, k_sin = k_rope
-            k = apply_rotary_emb_pt(k, k_cos, k_sin)
+            k_normed = apply_rotary_emb_pt(k_normed, k_cos, k_sin)

         # Split Heads for Attention
-        b, s_q, _ = q.shape
-        _, s_kv, _ = k.shape
-        q_h = q.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
-        k_h = k.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
+        b, s_q, _ = q_normed.shape
+        _, s_kv, _ = k_normed.shape
+        q_h = q_normed.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
+        k_h = k_normed.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
         v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)

         out = torch.nn.functional.scaled_dot_product_attention(
             q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0
         )
         out = out.transpose(1, 2).reshape(b, s_q, -1)
-        return self.to_out(out), (q, k, v, q, k, out)  # Returning normed q/k as placeholder
+        return self.to_out(out), (q_raw, k_raw, v, q_normed, k_normed, out)

 # ==========================================
 # 2. JAX Imports & Test Suite
 # ==========================================
-from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
+from ..models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed

 class LTX2AttentionTest(unittest.TestCase):

@@ -191,11 +193,11 @@ def test_shapes(self):
         model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")

         x_vid = jnp.zeros((1, 128, 64))
-        out_vid = model(x_vid, deterministic=True)
+        out_vid = model(x_vid)
         self.assertEqual(out_vid.shape, (1, 128, 64))

         x_aud = jnp.zeros((1, 32, 64))
-        out_cross = model(x_vid, encoder_hidden_states=x_aud, deterministic=True)
+        out_cross = model(x_vid, encoder_hidden_states=x_aud)
         self.assertEqual(out_cross.shape, (1, 128, 64))
         print("\n[PASS] Shape Tests Passed.")

@@ -221,7 +223,7 @@ def test_parity_bf16_strict(self):
         with torch.no_grad():
             pt_out, _ = pt_model(pt_in)

-        jax_out = jax_model(jax_in, deterministic=True)
+        jax_out = jax_model(jax_in)

         pt_res = pt_out.float().numpy()
         jax_res = np.array(jax_out, dtype=np.float32)
@@ -259,8 +261,12 @@ def add_stat(name, pt_t, jax_t):
             jax_val = np.array(jax_t, dtype=np.float32)
             stats.append({
                 "Layer": name,
+                "PT Max": f"{pt_val.max():.4f}",
+                "JAX Max": f"{jax_val.max():.4f}",
                 "PT Mean": f"{pt_val.mean():.4f}",
                 "JAX Mean": f"{jax_val.mean():.4f}",
+                "PT Min": f"{pt_val.min():.4f}",
+                "JAX Min": f"{jax_val.min():.4f}",
                 "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
             })

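For reference, here is a self-contained sketch of how the widened `add_stat` rows can be exercised. The helper mirrors the hunk above, and these rows pair naturally with the six intermediates (`q_raw`, `k_raw`, `v`, `q_normed`, `k_normed`, `out`) the PyTorch reference now returns; the standalone setup (module-level `stats`, random input) is illustrative, not the test's actual plumbing.

```python
import numpy as np
import torch

stats = []

def add_stat(name, pt_t, jax_t):
    # Same row layout as the test builds: max/mean/min per framework plus mean L1 diff.
    pt_val = pt_t.float().numpy()
    jax_val = np.array(jax_t, dtype=np.float32)
    stats.append({
        "Layer": name,
        "PT Max": f"{pt_val.max():.4f}",
        "JAX Max": f"{jax_val.max():.4f}",
        "PT Mean": f"{pt_val.mean():.4f}",
        "JAX Mean": f"{jax_val.mean():.4f}",
        "PT Min": f"{pt_val.min():.4f}",
        "JAX Min": f"{jax_val.min():.4f}",
        "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}",
    })

x = np.random.randn(1, 8, 64).astype(np.float32)
add_stat("q (pre-norm)", torch.from_numpy(x), x)  # identical inputs -> Diff (L1) = 0.000000
print(stats[0])
```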
@@ -291,7 +297,6 @@ def test_cross_attn_rope_integration(self):
         q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
         k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())

-        # No reshape needed! Passed directly as [B, S, InnerDim]
         with torch.no_grad():
             pt_out, _ = pt_model(
                 torch.from_numpy(np_x),
@@ -307,8 +312,7 @@ def test_cross_attn_rope_integration(self):
                 jnp.array(np_x),
                 encoder_hidden_states=jnp.array(np_ctx),
                 rotary_emb=jax_q_rope,
-                k_rotary_emb=jax_k_rope,
-                deterministic=True
+                k_rotary_emb=jax_k_rope
             )

         diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
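The hunk shows `q_cos_pt`/`q_sin_pt` coming from the PyTorch RoPE generator, but not how `jax_q_rope`/`jax_k_rope` are built. One plausible bridge for a parity test (an assumption, not code from this PR) is to reuse the PyTorch tables directly on the JAX side:

```python
import jax.numpy as jnp

# Assumed conversion: feed the PyTorch-generated cos/sin tables to the JAX model unchanged.
jax_q_rope = (jnp.asarray(q_cos_pt.float().numpy()), jnp.asarray(q_sin_pt.float().numpy()))
jax_k_rope = (jnp.asarray(k_cos_pt.float().numpy()), jnp.asarray(k_sin_pt.float().numpy()))
```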
@@ -342,8 +346,7 @@ def test_attention_mask_parity(self):
         with mesh:
             jax_out = jax_model(
                 jnp.array(np_x),
-                attention_mask=jax_mask_multiplicative,
-                deterministic=True
+                attention_mask=jax_mask_multiplicative
             )

         diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
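`jax_mask_multiplicative` is constructed outside this hunk, so its exact convention is not visible here. A minimal sketch of one common setup, assuming a `[B, S_kv]` 0/1 keep mask on the JAX side and a broadcastable boolean mask for torch's `scaled_dot_product_attention`; every name except `jax_mask_multiplicative` is a placeholder:

```python
import numpy as np
import torch
import jax.numpy as jnp

np_keep = np.ones((1, 128), dtype=np.float32)
np_keep[:, 96:] = 0.0  # drop the last 32 key positions

jax_mask_multiplicative = jnp.asarray(np_keep)                 # 0/1 float mask, [B, S_kv]
pt_mask = torch.from_numpy(np_keep).bool()[:, None, None, :]   # [B, 1, 1, S_kv], True = attend
```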