1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
17+ import unittest
18+ import torch
19+ import numpy as np
20+ import jax
21+ import jax .numpy as jnp
22+ from flax import nnx
23+ import pandas as pd
24+
# --- 1. Reference PyTorch Model (Minimal LTX-2 Logic) ---
class PytorchLTX2Attention(torch.nn.Module):
    def __init__(self, query_dim, context_dim, heads, dim_head):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head

        # LTX-2: RMSNorm on the full inner_dim (not per-head)
        self.q_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)
        self.k_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)

        # LTX-2: Linear layers with bias=True
        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
        self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
        self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)

        self.to_out = torch.nn.Sequential(
            torch.nn.Linear(inner_dim, query_dim, bias=True),
            torch.nn.Identity()
        )

    def forward(self, x, context=None):
        q = self.to_q(x)
        ctx = x if context is None else context
        k = self.to_k(ctx)
        v = self.to_v(ctx)

        # Norms (key check for LTX-2 vs. Flux: applied over the full
        # inner_dim, before the heads are split)
        q_norm = self.q_norm(q)
        k_norm = self.k_norm(k)

        # Reshape to [B, heads, S, dim_head]
        b, s_q, _ = q.shape
        _, s_kv, _ = k.shape

        q_h = q_norm.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
        k_h = k_norm.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
        v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)

        # Attention
        out = torch.nn.functional.scaled_dot_product_attention(q_h, k_h, v_h, dropout_p=0.0)
        out = out.transpose(1, 2).reshape(b, s_q, -1)

        return self.to_out(out), (q, k, v, q_norm, k_norm, out)  # Return intermediates for diagnostics

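# A minimal NumPy sketch of the RMSNorm semantics the parity check relies on
# (matching torch.nn.RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned
# weight over the last axis). `reference_rms_norm` is a hypothetical helper,
# not part of the models under test; it is only here to document the
# normalization being compared.
def reference_rms_norm(x: np.ndarray, scale: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * scale
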
# --- 2. Import JAX Model ---
from ..models.ltx2.attention_ltx2 import LTX2Attention

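# A minimal jnp reference for the attention math both models are expected to
# agree on: softmax(q k^T / sqrt(d)) v over per-head tensors shaped
# [B, heads, S, dim_head]. `reference_attention` is a hypothetical debugging
# helper (not used by the tests); it mirrors the defaults of torch's
# scaled_dot_product_attention above (no mask, no dropout).
def reference_attention(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
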
class LTX2ParityTest(unittest.TestCase):

    def setUp(self):
        self.B, self.S, self.D = 1, 16, 64
        self.heads = 4
        self.dim_head = 16
        self.context_dim = 64

        torch.manual_seed(0)
        np.random.seed(0)  # keep the random input reproducible across runs
        self.rng = nnx.Rngs(0)

        # Inputs
        self.np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)

    def _init_and_sync_models(self):
        """Initializes both models and copies PyTorch weights to JAX."""
        pt_model = PytorchLTX2Attention(self.D, self.context_dim, self.heads, self.dim_head)
        pt_model.eval()

        jax_model = LTX2Attention(
            query_dim=self.D, heads=self.heads, dim_head=self.dim_head, context_dim=self.context_dim,
            rngs=self.rng, attention_kernel="dot_product"
        )

        # Weight copy logic: torch.nn.Linear stores its weight as [out, in],
        # while the Flax kernel is [in, out], hence the transpose.
        def copy_linear(jax_layer, pt_layer):
            jax_layer.kernel.value = jnp.array(pt_layer.weight.detach().numpy().T)
            jax_layer.bias.value = jnp.array(pt_layer.bias.detach().numpy())

        def copy_norm(jax_layer, pt_layer):
            jax_layer.scale.value = jnp.array(pt_layer.weight.detach().numpy())

        copy_linear(jax_model.to_q, pt_model.to_q)
        copy_linear(jax_model.to_k, pt_model.to_k)
        copy_linear(jax_model.to_v, pt_model.to_v)
        copy_linear(jax_model.to_out, pt_model.to_out[0])
        copy_norm(jax_model.norm_q, pt_model.q_norm)
        copy_norm(jax_model.norm_k, pt_model.k_norm)

        return pt_model, jax_model

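    # Hypothetical helper, not part of the original sync logic: a cheap shape
    # check that catches a forgotten kernel transpose before any numerics run.
    # It assumes the same `kernel`/`bias` attribute names used in copy_linear.
    def _assert_synced_shapes(self, pt_model, jax_model):
        self.assertEqual(
            tuple(jax_model.to_q.kernel.value.shape),
            tuple(reversed(pt_model.to_q.weight.shape)),
        )
        self.assertEqual(
            tuple(jax_model.to_q.bias.value.shape),
            tuple(pt_model.to_q.bias.shape),
        )
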
    def test_parity_strict(self):
        """Standard Parity Test (Assertion)."""
        pt_model, jax_model = self._init_and_sync_models()

        with torch.no_grad():
            pt_out, _ = pt_model(torch.from_numpy(self.np_x))

        jax_out = jax_model(jnp.array(self.np_x))

        np.testing.assert_allclose(
            pt_out.numpy(), jax_out, atol=1e-5,
            err_msg="Strict Parity Failed: Outputs mismatch > 1e-5"
        )
        print("\n[PASS] Strict Parity Test passed.")

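    # Hypothetical convenience helper (not in the original test): collapses a
    # PyTorch/JAX comparison to a single scalar, which is easier to track when
    # the strict assertion above starts failing after a refactor.
    def _max_abs_diff(self, pt_t, jax_t):
        return float(np.max(np.abs(pt_t.detach().numpy() - np.array(jax_t))))
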
    def test_layer_wise_stats(self):
        """Diagnostic test: prints layer-wise stats instead of asserting."""
        pt_model, jax_model = self._init_and_sync_models()

        # 1. Run PyTorch (get intermediates)
        with torch.no_grad():
            pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(torch.from_numpy(self.np_x))

        # 2. Run JAX step-by-step (manual re-run to get intermediates)
        x = jnp.array(self.np_x)
        jax_q = jax_model.to_q(x)
        jax_k = jax_model.to_k(x)  # self-attention: context is x itself
        jax_v = jax_model.to_v(x)

        jax_qn = jax_model.norm_q(jax_q)
        jax_kn = jax_model.norm_k(jax_k)

        # JAX attention: apply_attention is assumed to consume flattened
        # [B, S, heads * dim_head] tensors and split heads internally, so the
        # normalized projections are passed through unchanged.
        jax_attn = jax_model.attention_op.apply_attention(jax_qn, jax_kn, jax_v)
        jax_out = jax_model.to_out(jax_attn)

        # 3. Compare stats
        stats = []

        def add_stat(name, pt_t, jax_t):
            pt_val = pt_t.numpy() if isinstance(pt_t, torch.Tensor) else pt_t
            jax_val = np.array(jax_t)
            stats.append({
                "Layer": name,
                "PT Mean": f"{pt_val.mean():.4f}",
                "JAX Mean": f"{jax_val.mean():.4f}",
                "Diff (Mean L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
            })

        add_stat("Query Proj", pt_q, jax_q)
        add_stat("Key Proj", pt_k, jax_k)
        add_stat("Value Proj", pt_v, jax_v)
        add_stat("Query Norm", pt_qn, jax_qn)
        add_stat("Key Norm", pt_kn, jax_kn)
        add_stat("Attn Output", pt_attn, jax_attn)
        add_stat("Final Output", pt_out, jax_out)

        df = pd.DataFrame(stats)
        print("\n[DIAGNOSTIC] Layer-wise Stats:")
        print(df.to_string(index=False))

if __name__ == "__main__":
    unittest.main()