
Commit 7b5d4a0

Attention
1 parent 1f7888b commit 7b5d4a0

1 file changed

Lines changed: 181 additions & 0 deletions

"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import torch
import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx
import pandas as pd

# --- 1. Reference PyTorch Model (Minimal LTX-2 Logic) ---
class PytorchLTX2Attention(torch.nn.Module):
    def __init__(self, query_dim, context_dim, heads, dim_head):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head

        # LTX-2: RMSNorm over the full inner_dim (not per-head), applied to
        # the Q/K projections before the head split
        self.q_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)
        self.k_norm = torch.nn.RMSNorm(inner_dim, eps=1e-6)

        # LTX-2: Linear layers with bias=True
        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
        self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
        self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)

        self.to_out = torch.nn.Sequential(
            torch.nn.Linear(inner_dim, query_dim, bias=True),
            torch.nn.Identity()
        )

    def forward(self, x, context=None):
        q = self.to_q(x)
        ctx = x if context is None else context  # self-attention when context is None
        k = self.to_k(ctx)
        v = self.to_v(ctx)

        # Norms (key check for LTX-2 vs Flux)
        q_norm = self.q_norm(q)
        k_norm = self.k_norm(k)

        # Reshape to the (batch, heads, seq, dim_head) layout expected by SDPA
        b, s_q, _ = q.shape
        _, s_kv, _ = k.shape

        q_h = q_norm.view(b, s_q, self.heads, self.dim_head).transpose(1, 2)
        k_h = k_norm.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)
        v_h = v.view(b, s_kv, self.heads, self.dim_head).transpose(1, 2)

        # Attention
        out = torch.nn.functional.scaled_dot_product_attention(q_h, k_h, v_h, dropout_p=0.0)
        out = out.transpose(1, 2).reshape(b, s_q, -1)

        # Return intermediates alongside the output for layer-wise diagnostics
        return self.to_out(out), (q, k, v, q_norm, k_norm, out)

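# A minimal shape sketch of the head split above (illustrative helper, not
# part of the test suite): view (b, s, heads * dim_head) as
# (b, s, heads, dim_head), then swap the seq and head axes to get the
# (b, heads, s, dim_head) layout that scaled_dot_product_attention expects.
def _head_split_sketch():
    t = torch.randn(1, 16, 64)                   # (b, s, heads * dim_head)
    t_h = t.view(1, 16, 4, 16).transpose(1, 2)   # -> (b, heads, s, dim_head)
    assert t_h.shape == (1, 4, 16, 16)
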
# --- 2. Import JAX Model ---
from ..models.ltx2.attention_ltx2 import LTX2Attention

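# NOTE: because of the package-relative import above, this file must be run
# with package context (e.g. `python -m unittest` with the module's dotted
# path); executing it as a bare script raises ImportError.
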
class LTX2ParityTest(unittest.TestCase):

    def setUp(self):
        self.B, self.S, self.D = 1, 16, 64
        self.heads = 4
        self.dim_head = 16
        self.context_dim = 64

        # Seed PyTorch, NumPy, and nnx RNGs so weights and inputs are reproducible
        torch.manual_seed(0)
        np.random.seed(0)
        self.rng = nnx.Rngs(0)

        # Inputs
        self.np_x = np.random.randn(self.B, self.S, self.D).astype(np.float32)

    def _init_and_sync_models(self):
        """Initializes both models and copies PyTorch weights to JAX."""
        pt_model = PytorchLTX2Attention(self.D, self.context_dim, self.heads, self.dim_head)
        pt_model.eval()

        jax_model = LTX2Attention(
            query_dim=self.D, heads=self.heads, dim_head=self.dim_head, context_dim=self.context_dim,
            rngs=self.rng, attention_kernel="dot_product"
        )

        # Weight copy logic: Flax kernels are stored as (in, out) while
        # PyTorch Linear weights are (out, in), hence the transpose.
        def copy_linear(jax_layer, pt_layer):
            jax_layer.kernel.value = jnp.array(pt_layer.weight.detach().numpy().T)
            jax_layer.bias.value = jnp.array(pt_layer.bias.detach().numpy())

        def copy_norm(jax_layer, pt_layer):
            jax_layer.scale.value = jnp.array(pt_layer.weight.detach().numpy())

        copy_linear(jax_model.to_q, pt_model.to_q)
        copy_linear(jax_model.to_k, pt_model.to_k)
        copy_linear(jax_model.to_v, pt_model.to_v)
        copy_linear(jax_model.to_out, pt_model.to_out[0])
        copy_norm(jax_model.norm_q, pt_model.q_norm)
        copy_norm(jax_model.norm_k, pt_model.k_norm)

        return pt_model, jax_model

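    # Illustrative sanity check for the copy convention above (helper, not
    # collected as a test): PyTorch computes x @ weight.T + bias, Flax
    # computes x @ kernel + bias, so kernel = weight.T gives the same map.
    def _linear_copy_sketch(self):
        pt = torch.nn.Linear(4, 8)
        x = np.random.randn(2, 4).astype(np.float32)
        y_pt = pt(torch.from_numpy(x)).detach().numpy()
        y_np = x @ pt.weight.detach().numpy().T + pt.bias.detach().numpy()
        np.testing.assert_allclose(y_pt, y_np, atol=1e-6)
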
    def test_parity_strict(self):
        """Standard Parity Test (Assertion)."""
        pt_model, jax_model = self._init_and_sync_models()

        with torch.no_grad():
            pt_out, _ = pt_model(torch.from_numpy(self.np_x))

        jax_out = jax_model(jnp.array(self.np_x))

        np.testing.assert_allclose(
            pt_out.numpy(), jax_out, atol=1e-5,
            err_msg="Strict Parity Failed: Outputs mismatch > 1e-5"
        )
        print("\n[PASS] Strict Parity Test passed.")

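    # To run just this check in isolation (dotted path illustrative, given
    # the package-relative import at the top of the file):
    #   python -m unittest <pkg>.tests.<module>.LTX2ParityTest.test_parity_strict
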
    def test_layer_wise_stats(self):
        """Diagnostic Test: Prints layer-wise stats."""
        pt_model, jax_model = self._init_and_sync_models()

        # 1. Run PyTorch (get intermediates)
        with torch.no_grad():
            pt_out, (pt_q, pt_k, pt_v, pt_qn, pt_kn, pt_attn) = pt_model(torch.from_numpy(self.np_x))

        # 2. Run JAX step by step (manual re-run to get intermediates)
        x = jnp.array(self.np_x)
        jax_q = jax_model.to_q(x)
        jax_k = jax_model.to_k(x)  # self-attention: context is x itself
        jax_v = jax_model.to_v(x)

        jax_qn = jax_model.norm_q(jax_q)
        jax_kn = jax_model.norm_k(jax_k)

        # JAX reshape & attention. apply_attention consumes the flattened
        # (b, s, heads * dim_head) layout; the intermediate 4-D reshape is a
        # round trip that only verifies the dims factor into heads * dim_head.
        b, s, _ = jax_qn.shape
        q_h = jax_qn.reshape(b, s, self.heads, self.dim_head).reshape(b, s, -1)
        k_h = jax_kn.reshape(b, s, self.heads, self.dim_head).reshape(b, s, -1)
        v_h = jax_v.reshape(b, s, self.heads, self.dim_head).reshape(b, s, -1)

        jax_attn = jax_model.attention_op.apply_attention(q_h, k_h, v_h)
        jax_out = jax_model.to_out(jax_attn)

        # 3. Compare stats
        stats = []

        def add_stat(name, pt_t, jax_t):
            pt_val = pt_t.numpy() if isinstance(pt_t, torch.Tensor) else pt_t
            jax_val = np.array(jax_t)
            stats.append({
                "Layer": name,
                "PT Mean": f"{pt_val.mean():.4f}",
                "JAX Mean": f"{jax_val.mean():.4f}",
                "Diff (Mean L1)": f"{np.abs(pt_val - jax_val).mean():.6f}"
            })

        add_stat("Query Proj", pt_q, jax_q)
        add_stat("Key Proj", pt_k, jax_k)
        add_stat("Value Proj", pt_v, jax_v)
        add_stat("Query Norm", pt_qn, jax_qn)
        add_stat("Key Norm", pt_kn, jax_kn)
        add_stat("Attn Output", pt_attn, jax_attn)
        add_stat("Final Output", pt_out, jax_out)

        df = pd.DataFrame(stats)
        print("\n[DIAGNOSTIC] Layer-wise Stats:")
        print(df.to_string(index=False))

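    # With correctly synced weights, the "Diff (Mean L1)" column should be
    # ~0 at every row; the first row where it jumps localizes the divergence.
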
if __name__ == "__main__":
    unittest.main()
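
# NOTE: pytest also collects these unittest.TestCase tests; run with
# `pytest -s` to see the printed diagnostic table (pytest captures stdout
# by default).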
