 import jax.numpy as jnp
 from flax import nnx
 import pandas as pd
+from jax.sharding import Mesh

-# Set JAX to use float32 for precision checks
+# Set JAX to use float32 for higher precision checks
 jax.config.update("jax_default_matmul_precision", "float32")

 # ==========================================
@@ -44,11 +45,11 @@ def forward(self, ids):
         num_axes = ids.shape[-1]
         dim_per_axis = self.dim // num_axes

-        freqs_list = []
         # Standard RoPE frequencies: theta^(-2i/d)
         freq_indices = torch.arange(0, dim_per_axis, 2, dtype=torch.float32)
         inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))

+        freqs_list = []
         for i in range(num_axes):
             axis_pos = ids[..., i]  # [B, S]
             # Outer product: [B, S, 1] * [1, 1, D/2] -> [B, S, D/2]
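As a sanity check on the frequency rule above, a minimal standalone sketch (dim_per_axis = 20 and theta = 10000 are assumed values, matching dim = 60 split over 3 axes):

```python
import numpy as np

dim_per_axis, theta = 20, 10000.0
i = np.arange(0, dim_per_axis, 2, dtype=np.float32)  # 0, 2, ..., 18
inv_freq = 1.0 / theta ** (i / dim_per_axis)         # theta^(-2k/d) for pair k

assert inv_freq[0] == 1.0                            # first pair rotates fastest
assert np.isclose(inv_freq[-1], theta ** -0.9)       # last pair rotates slowest
```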
@@ -65,7 +66,7 @@ def forward(self, ids):
         cos = torch.repeat_interleave(cos, 2, dim=-1)
         sin = torch.repeat_interleave(sin, 2, dim=-1)

-        # Add head dim: [B, S, 1, D]
+        # Add head dim for broadcasting: [B, S, 1, D]
         return cos.unsqueeze(2), sin.unsqueeze(2)

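The repeat_interleave above duplicates each angle so consecutive channel pairs share it; a minimal sketch:

```python
import torch

freqs = torch.tensor([[0.1, 0.2]])                   # two angles
doubled = torch.repeat_interleave(freqs, 2, dim=-1)  # pair (2i, 2i+1) shares angle i
assert torch.equal(doubled, torch.tensor([[0.1, 0.1, 0.2, 0.2]]))
```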
@@ -77,7 +78,7 @@ def apply_rotary_emb_pt(x, cos, sin):
     x1, x2 = x_reshaped.unbind(-1)
     x_rotated = torch.stack((-x2, x1), dim=-1).view(b, h, s, d)

-    # Cast to float32 for rotation parity
+    # Cast to float32 for rotation parity with JAX
     orig_dtype = x.dtype
     x_f32 = x.to(torch.float32)
     rot_f32 = x_rotated.to(torch.float32)
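For reference, the (-x2, x1) stacking is the standard 2-D rotation applied per channel pair; a self-contained check:

```python
import torch

angle = torch.tensor(0.7)
cos, sin = torch.cos(angle), torch.sin(angle)
x = torch.tensor([0.3, 0.5])              # one (x1, x2) pair
rotated = torch.stack((-x[1], x[0]))      # (-x2, x1), as in apply_rotary_emb_pt
out = x * cos + rotated * sin

expected = torch.stack((x[0] * cos - x[1] * sin,
                        x[1] * cos + x[0] * sin))
assert torch.allclose(out, expected)
```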
@@ -108,7 +109,7 @@ def __init__(self, query_dim, context_dim, heads, dim_head):
             torch.nn.Identity()
         )

-    def forward(self, x, context=None, q_rope=None, k_rope=None):
+    def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):
         q = self.to_q(x)
         ctx = x if context is None else context
         k = self.to_k(ctx)
@@ -133,15 +134,19 @@ def forward(self, x, context=None, q_rope=None, k_rope=None):
             k_cos, k_sin = k_rope
             k_h = apply_rotary_emb_pt(k_h, k_cos, k_sin)

-        out = torch.nn.functional.scaled_dot_product_attention(q_h, k_h, v_h, dropout_p=0.0)
+        # PyTorch SDPA expects an additive mask broadcastable to [B, H, S, S]
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q_h, k_h, v_h, attn_mask=mask, dropout_p=0.0
+        )
+
         out = out.transpose(1, 2).reshape(b, s_q, -1)

-        return self.to_out(out)
+        return self.to_out(out), (q, k, v, q_norm, k_norm, out)

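For reference, a minimal sketch of the additive-mask convention the new `mask` argument follows (shapes illustrative, not from the PR): zeros keep a position, large negatives remove it, and a [B, 1, 1, S_kv] mask broadcasts across heads and query positions.

```python
import torch

keep = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # [B=1, S_kv=4] padding mask, 1 = keep
additive = (1.0 - keep) * -1e9               # 0 where kept, -1e9 where masked
additive = additive[:, None, None, :]        # [B, 1, 1, S_kv] -> [B, H, S_q, S_kv]

q = torch.randn(1, 2, 3, 8)                  # [B, H, S_q, head_dim]
k = v = torch.randn(1, 2, 4, 8)              # [B, H, S_kv, head_dim]
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=additive)
assert out.shape == (1, 2, 3, 8)
```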
141146# ==========================================
142147# 2. JAX Imports
143148# ==========================================
144- from . .models .ltx2 .attention_ltx2 import LTX2Attention , LTX2RotaryPosEmbed
149+ from maxdiffusion .models .ltx2 .attention_ltx2 import LTX2Attention , LTX2RotaryPosEmbed
145150
146151class LTX2AttentionTest (unittest .TestCase ):
147152
@@ -192,48 +197,54 @@ def copy_norm(jax_layer, pt_layer):
         return pt_model, jax_model

     # ------------------------------------------
-    # 1. RoPE Frequency Parity Test
+    # 1. Output Shape Tests
     # ------------------------------------------
-    def test_rope_frequency_parity(self):
-        """
-        Verifies that LTX2RotaryPosEmbed (JAX) generates the EXACT same
-        frequencies as the PyTorch reference for a given input ID set.
-        """
-        dim = 60  # Divisible by 3 for 3D test
+    def test_shapes(self):
+        """Verifies the JAX model handles Video (3D) and Audio (1D) shapes."""
+        model = LTX2Attention(64, 4, 16, 64, rngs=self.rng, attention_kernel="dot_product")
+
+        # Video: [B, S, D]
+        x_vid = jnp.zeros((1, 128, 64))
+        out_vid = model(x_vid)
+        self.assertEqual(out_vid.shape, (1, 128, 64))

+        # Audio cross-attention: queries from video, so output keeps S_vid: [B, S_vid, D]
+        x_aud = jnp.zeros((1, 32, 64))
+        out_cross = model(x_vid, encoder_hidden_states=x_aud)
+        self.assertEqual(out_cross.shape, (1, 128, 64))
+        print("\n[PASS] Shape Tests Passed.")
+
+    # ------------------------------------------
+    # 2. RoPE Frequency Parity
+    # ------------------------------------------
+    def test_rope_frequency_parity(self):
+        """Verifies JAX RoPE frequencies match PyTorch."""
+        dim = 60  # divisible by 3 for the 3D axes
         rope_pt = PytorchLTX2RotaryPosEmbed(dim=dim)
         rope_jax = LTX2RotaryPosEmbed(dim=dim)

-        # Create random IDs [B, S, 3]
         np_ids = np.random.randint(0, 100, (2, 16, 3)).astype(np.float32)

-        # Run PyTorch
         pt_cos, pt_sin = rope_pt(torch.from_numpy(np_ids))
-
-        # Run JAX
         jax_cos, jax_sin = rope_jax(jnp.array(np_ids))

-        # Compare
-        pt_cos_np = pt_cos.numpy()
-        jax_cos_np = np.array(jax_cos)
-
         # 1e-5 tolerance for the frequency-generation math
-        np.testing.assert_allclose(pt_cos_np, jax_cos_np, atol=1e-5)
+        np.testing.assert_allclose(pt_cos.numpy(), np.array(jax_cos), atol=1e-5)
         np.testing.assert_allclose(pt_sin.numpy(), np.array(jax_sin), atol=1e-5)
-        print("\n[PASS] RoPE Frequency Generation matches PyTorch.")
+        print("[PASS] RoPE Frequency Parity Verified.")

     # ------------------------------------------
-    # 2. Strict Parity Test (Full Model)
+    # 3. Strict Parity Test (Full Model, BF16)
     # ------------------------------------------
     def test_parity_bf16_strict(self):
-        """Checks if JAX(TPU) matches PyTorch(CPU) in BF16."""
+        """Checks if JAX matches PyTorch in BF16."""
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)

         pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
         jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)

         with torch.no_grad():
-            pt_out = pt_model(pt_in)
+            pt_out, _ = pt_model(pt_in)

         jax_out = jax_model(jax_in)

@@ -247,40 +258,18 @@ def test_parity_bf16_strict(self):
         print("\n[PASS] BF16 Strict Parity Test passed.")

     # ------------------------------------------
-    # 3. Layer-wise Stats (Corrected Shape Logic)
+    # 4. Layer-wise Diagnostics
     # ------------------------------------------
     def test_layer_wise_stats(self):
-        """Prints diagnostic stats for every layer."""
+        """Prints diagnostic stats for every layer (Bfloat16)."""
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)

         pt_in = torch.from_numpy(self.np_x).to(device="cpu", dtype=torch.bfloat16)
         jax_in = jnp.array(self.np_x).astype(jnp.bfloat16)

-        # 1. Run PyTorch Step-by-Step
+        # 1. Run PyTorch Step-by-Step (get intermediates from the tuple return)
         with torch.no_grad():
-            # Projections
-            pt_q = pt_model.to_q(pt_in)
-            pt_k = pt_model.to_k(pt_in)
-            pt_v = pt_model.to_v(pt_in)
-
-            # Norms
-            pt_qn = pt_model.q_norm(pt_q)
-            pt_kn = pt_model.k_norm(pt_k)
-
-            # Attention Prep (Reshape -> Transpose)
-            b, s, _ = pt_qn.shape
-            pt_q_h = pt_qn.view(b, s, self.heads, self.dim_head).transpose(1, 2)
-            pt_k_h = pt_kn.view(b, s, self.heads, self.dim_head).transpose(1, 2)
-            pt_v_h = pt_v.view(b, s, self.heads, self.dim_head).transpose(1, 2)
-
-            # Attention Op
-            pt_attn_out = torch.nn.functional.scaled_dot_product_attention(pt_q_h, pt_k_h, pt_v_h)
-
-            # Reshape Back
-            pt_attn_flat = pt_attn_out.transpose(1, 2).reshape(b, s, -1)
-
-            # Output
-            pt_out = pt_model.to_out(pt_attn_flat)
+            pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(pt_in)

         # 2. Run JAX Step-by-Step
         jax_q = jax_model.to_q(jax_in)
@@ -290,45 +279,133 @@ def test_layer_wise_stats(self):
         jax_qn = jax_model.norm_q(jax_q)
         jax_kn = jax_model.norm_k(jax_k)

-        # Attention Op: Pass [B, S, Inner_Dim] directly
-        # The LTX2Attention.__call__ flattens inputs before calling apply_attention,
-        # so we pass the flattened (Inner_Dim) tensors here.
+        # Pass 3D tensors [B, S, Inner_Dim] directly to the attention op;
+        # NNXAttentionOp handles the internal logic for the kernel.
         jax_attn = jax_model.attention_op.apply_attention(jax_qn, jax_kn, jax_v)
         jax_out = jax_model.to_out(jax_attn)

-        # 3. Print Comparison Table
+        # 3. Build & Print Comparison Table
         stats = []
         def add_stat(name, pt_t, jax_t):
-            pt_val = pt_t.float().numpy() if isinstance(pt_t, torch.Tensor) else pt_t
+            # Ensure pt_t is a tensor before calling .float().numpy()
+            if isinstance(pt_t, torch.Tensor):
+                pt_val = pt_t.float().numpy()
+            else:
+                pt_val = pt_t
+
             jax_val = np.array(jax_t, dtype=np.float32)
+
             stats.append({
                 "Layer": name,
                 "PT Mean": f"{pt_val.mean():.4f}",
                 "JAX Mean": f"{jax_val.mean():.4f}",
                 "PT Min": f"{pt_val.min():.4f}",
                 "JAX Min": f"{jax_val.min():.4f}",
-                "PT Max": f"{pt_val.max():.4f}",
-                "JAX Max": f"{jax_val.max():.4f}",
-                "Diff (Mean L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
+                "Diff (L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
             })

         add_stat("Query Proj", pt_q, jax_q)
         add_stat("Key Proj", pt_k, jax_k)
         add_stat("Value Proj", pt_v, jax_v)
         add_stat("Query Norm", pt_qn, jax_qn)
         add_stat("Key Norm", pt_kn, jax_kn)
-        add_stat("Attn Output", pt_attn_flat, jax_attn)
+        add_stat("Attn Output", pt_attn, jax_attn)
         add_stat("Final Output", pt_out, jax_out)

         df = pd.DataFrame(stats)
-        print("\n[DIAGNOSTIC] Layer-wise Stats (CPU vs TPU BFloat16):")
+        print("\n[DIAGNOSTIC] Layer-wise Stats:")
         print(df.to_string(index=False))
+
     # ------------------------------------------
-    # 4. Cross-Attention + RoPE Integration
+    # 5. Cross-Attention + RoPE Integration
     # ------------------------------------------
     def test_cross_attn_rope_integration(self):
-        """Verifies Video->Audio cross-attention with RoPE."""
+        """Verifies Video->Audio cross-attention with RoPE (Float32)."""
         S_Q, S_KV = 16, 20
         pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)

-        np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
+        np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
+        np_ctx = np.random.randn(self.B, S_KV, self.D).astype(np.float32)
+
+        rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)
+
+        ids_q = torch.randint(0, 100, (self.B, S_Q, 1))
+        ids_k = torch.randint(0, 100, (self.B, S_KV, 1))
+
+        q_cos_pt, q_sin_pt = rope_gen_pt(ids_q.float())
+        k_cos_pt, k_sin_pt = rope_gen_pt(ids_k.float())
+
+        # Split cos/sin per head: [B, S, 1, D] -> [B, H, S, D/H]
+        def prep_pt(c, s):
+            c = c.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
+            s = s.view(self.B, -1, self.heads, self.dim_head).transpose(1, 2)
+            return c, s
+
+        pt_q_rope = prep_pt(q_cos_pt, q_sin_pt)
+        pt_k_rope = prep_pt(k_cos_pt, k_sin_pt)
+
+        with torch.no_grad():
+            pt_out, _ = pt_model(
+                torch.from_numpy(np_x),
+                context=torch.from_numpy(np_ctx),
+                q_rope=pt_q_rope,
+                k_rope=pt_k_rope,
+            )
+
+        jax_q_rope = (jnp.array(q_cos_pt.numpy()), jnp.array(q_sin_pt.numpy()))
+        jax_k_rope = (jnp.array(k_cos_pt.numpy()), jnp.array(k_sin_pt.numpy()))
+
+        jax_out = jax_model(
+            jnp.array(np_x),
+            encoder_hidden_states=jnp.array(np_ctx),
+            rotary_emb=jax_q_rope,
+            k_rotary_emb=jax_k_rope,
+        )
+
+        diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
+        print(f"\n[Cross-Attn + RoPE] Max Diff: {diff:.6f}")
+        np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+        print("[PASS] Cross-Attention with RoPE Parity Verified.")
+
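To make the prep_pt reshape concrete, a shape walk-through under the dims the test appears to use (B=1, heads=4, dim_head=16, so D=64):

```python
import torch

B, S, H, Dh = 1, 16, 4, 16
cos = torch.zeros(B, S, 1, H * Dh)                 # rope output: [B, S, 1, D]
per_head = cos.view(B, -1, H, Dh).transpose(1, 2)  # -> [B, H, S, Dh]
assert per_head.shape == (B, H, S, Dh)
```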
+    # ------------------------------------------
+    # 6. Attention Mask Parity
+    # ------------------------------------------
+    def test_attention_mask_parity(self):
+        """
+        Verifies attention (padding) masks work identically using the FLASH kernel.
+        The flash kernel in attention_flax expects a multiplicative mask [B, S],
+        while PyTorch SDPA expects an additive mask broadcastable to [B, H, S, S].
+        """
+        pt_model, jax_model = self._init_and_sync_models(dtype=jnp.float32)
+
+        # Switch the JAX model to flash attention for this test
+        jax_model.attention_op.attention_kernel = "flash"
+        jax_model.attention_op.mesh = Mesh(jax.devices(), ("x",))
+
+        np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)
+
+        # Mask pattern: 1 = keep, 0 = mask out. Shape: [B, S]
+        mask_pattern_np = np.random.randint(0, 2, (self.B, self.S)).astype(np.float32)
+
+        # PyTorch needs an ADDITIVE mask: 0 for keep, a large negative for mask out.
+        # [B, 1, 1, S] broadcasts to [B, H, S_q, S_kv], e.g. [1, 1, 1, 16] for B=1, H=4, S=16.
+        pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
+
+        # JAX flash attention needs a MULTIPLICATIVE mask: [B, S], here [1, 16]
+        jax_mask_multiplicative = jnp.array(mask_pattern_np)
+
+        # PyTorch
+        with torch.no_grad():
+            pt_out, _ = pt_model(torch.from_numpy(np_x), mask=pt_mask_additive)
+
+        # JAX
+        jax_out = jax_model(
+            jnp.array(np_x),
+            attention_mask=jax_mask_multiplicative,
+        )
+
+        np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+        print("[PASS] Attention Mask Parity Verified.")
+
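A quick way to see that the two mask conventions agree: a multiplicative 0 on a key corresponds to adding a large negative number to its score before softmax. A minimal numpy sketch with arbitrary scores:

```python
import numpy as np

scores = np.array([1.0, 2.0, 3.0])       # raw attention scores for 3 keys
keep = np.array([1.0, 1.0, 0.0])         # multiplicative mask: drop the last key

additive = scores + (1.0 - keep) * -1e9  # additive convention
w = np.exp(additive - additive.max())
w /= w.sum()                             # softmax over masked scores

assert w[-1] < 1e-6                      # masked key gets ~zero weight
assert np.isclose(w[:2].sum(), 1.0)      # probability mass stays on kept keys
```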
+if __name__ == "__main__":
+    unittest.main()