Commit 2419540

using latest attn test
1 parent 317698e commit 2419540

2 files changed

Lines changed: 83 additions & 47 deletions

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 0 additions & 5 deletions
@@ -929,11 +929,6 @@ def __call__(
       audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1)

     batch_size = hidden_states.shape[0]
-    # print_shape("Model Input hidden_states", hidden_states)
-    # print_shape("Model Input audio_hidden_states", audio_hidden_states)
-    # print_shape("Model Input encoder_hidden_states", encoder_hidden_states)
-    # print_shape("Model Input audio_encoder_hidden_states", audio_encoder_hidden_states)
-    # print_shape("Model Input timestep", timestep)

     # 1. Prepare RoPE positional embeddings
     with self.conditional_named_scope("rotary_embedding"):

src/maxdiffusion/tests/ltx2_attention_test.py

Lines changed: 83 additions & 42 deletions
@@ -1,5 +1,5 @@
 """
-Copyright 2025 Google LLC
+Copyright 2026 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -22,47 +22,86 @@
 from flax import nnx
 import pandas as pd
 from jax.sharding import Mesh
-from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
-
-# Set JAX to use float32 for higher precision checks
-jax.config.update("jax_default_matmul_precision", "float32")
+from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed

 # ==========================================
 # 1. PyTorch Reference Implementations
 # ==========================================


 class PytorchLTX2RotaryPosEmbed(torch.nn.Module):
+  """
+  Exact mathematical replica of Diffusers LTX2AudioVideoRotaryPosEmbed.forward,
+  stripped down for testing the core RoPE frequency generation logic.
+  """

-  def __init__(self, dim: int, theta: float = 10000.0):
+  def __init__(
+      self, dim: int, theta: float = 10000.0, base_dims=(20, 2048, 2048), rope_type="interleaved", num_attention_heads=32
+  ):
     super().__init__()
     self.dim = dim
     self.theta = theta
+    self.base_dims = base_dims
+    self.rope_type = rope_type
+    self.num_attention_heads = num_attention_heads
+    self.double_precision = True

   def forward(self, ids):
+    # Test passes ids as [Batch, Sequence, NumAxes]
     num_axes = ids.shape[-1]
-    dim_per_axis = self.dim // num_axes

-    freq_indices = torch.arange(0, dim_per_axis, 2, dtype=torch.float32)
-    inv_freq = 1.0 / (self.theta ** (freq_indices / dim_per_axis))
+    # 1. Scale by max_positions -> [B, S, num_axes]
+    max_pos = torch.tensor(self.base_dims[:num_axes], dtype=torch.float32, device=ids.device)
+    grid = ids / max_pos.view(1, 1, num_axes)
+
+    # 2. Map to [-1, 1]
+    scaled_grid = grid * 2.0 - 1.0

-    freqs_list = []
-    for i in range(num_axes):
-      axis_pos = ids[..., i]
-      freqs = torch.einsum("bs,d->bsd", axis_pos, inv_freq)
-      freqs_list.append(freqs)
+    # 3. Base Frequencies
+    num_rope_elems = num_axes * 2
+    dim_per_axis = self.dim // num_rope_elems
+    freqs_dtype = torch.float64 if self.double_precision else torch.float32
+    pow_indices = torch.pow(
+        self.theta,
+        torch.linspace(start=0.0, end=1.0, steps=dim_per_axis, dtype=freqs_dtype, device=ids.device),
+    )
+    base_freqs = (pow_indices * (torch.pi / 2.0)).to(dtype=torch.float32)  # [steps]

-    # Concatenate axes -> [B, S, D/2]
-    emb = torch.cat(freqs_list, dim=-1)
+    # 4. Outer Product & Transpose (Diffusers specific logic)
+    # grid: [B, S, num_axes, 1] * base_freqs: [steps] -> [B, S, num_axes, steps]
+    freqs = scaled_grid.unsqueeze(-1) * base_freqs
+    # Transpose last two dims: [B, S, steps, num_axes]
+    freqs = freqs.transpose(-1, -2)
+    # Flatten: [B, S, steps * num_axes]
+    emb = freqs.flatten(2)

     cos = torch.cos(emb)
     sin = torch.sin(emb)

-    # Interleave: [c1, c2] -> [c1, c1, c2, c2]
-    cos = torch.repeat_interleave(cos, 2, dim=-1)
-    sin = torch.repeat_interleave(sin, 2, dim=-1)
+    if self.rope_type == "interleaved":
+      # Interleave: [c1, c2] -> [c1, c1, c2, c2]
+      cos = torch.repeat_interleave(cos, 2, dim=-1)
+      sin = torch.repeat_interleave(sin, 2, dim=-1)
+
+      if self.dim % num_rope_elems != 0:
+        pad_amt = self.dim - cos.shape[-1]
+        cos_padding = torch.ones_like(cos[..., :pad_amt])
+        sin_padding = torch.zeros_like(sin[..., :pad_amt])
+        cos = torch.cat([cos_padding, cos], dim=-1)
+        sin = torch.cat([sin_padding, sin], dim=-1)
+
+    elif self.rope_type == "split":
+      pad_size = (self.dim // 2) - cos.shape[-1]
+      if pad_size > 0:
+        cos_padding = torch.ones_like(cos[..., :pad_size])
+        sin_padding = torch.zeros_like(sin[..., :pad_size])
+        cos = torch.cat([cos_padding, cos], dim=-1)
+        sin = torch.cat([sin_padding, sin], dim=-1)
+
+      b, s, _ = cos.shape
+      cos = cos.view(b, s, self.num_attention_heads, -1).transpose(1, 2)
+      sin = sin.view(b, s, self.num_attention_heads, -1).transpose(1, 2)

-    # Return [B, S, InnerDim] to match JAX/LTX-2 global RoPE
     return cos, sin
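For orientation, the same frequency construction can be sketched directly in JAX. This is only an illustration of the math the parity test exercises, assuming the interleaved layout with no padding (dim divisible by 2 * num_axes) and single precision; the function name rope_freqs_sketch is hypothetical and this is not the maxdiffusion LTX2RotaryPosEmbed implementation.

import jax.numpy as jnp

def rope_freqs_sketch(ids, dim, theta=10000.0, base_dims=(20, 2048, 2048)):
  # ids: [B, S, num_axes] float positions; dim is assumed divisible by 2 * num_axes.
  num_axes = ids.shape[-1]
  max_pos = jnp.asarray(base_dims[:num_axes], dtype=jnp.float32)
  scaled_grid = (ids / max_pos) * 2.0 - 1.0                   # map positions to [-1, 1]

  steps = dim // (num_axes * 2)
  base_freqs = (theta ** jnp.linspace(0.0, 1.0, steps)) * (jnp.pi / 2.0)

  freqs = scaled_grid[..., None] * base_freqs                 # [B, S, num_axes, steps]
  emb = jnp.swapaxes(freqs, -1, -2).reshape(ids.shape[0], ids.shape[1], -1)

  cos = jnp.repeat(jnp.cos(emb), 2, axis=-1)                  # interleave: [c1, c1, c2, c2, ...]
  sin = jnp.repeat(jnp.sin(emb), 2, axis=-1)
  return cos, sin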

@@ -138,19 +177,18 @@ def forward(self, x, context=None, q_rope=None, k_rope=None, mask=None):


 # ==========================================
-# 2. JAX Imports & Test Suite
+# 2. JAX Test Suite
 # ==========================================
-from ..models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed


 class LTX2AttentionTest(unittest.TestCase):

   def setUp(self):
     # S=128 is preferred for TPU Flash Attention block sizes
-    self.B, self.S, self.D = 1, 128, 64
+    self.B, self.S, self.D = 1, 128, 512
     self.heads = 4
-    self.dim_head = 16
-    self.context_dim = 64
+    self.dim_head = 128
+    self.context_dim = 512

     torch.manual_seed(0)
     self.rng = nnx.Rngs(0)
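With the updated setUp dimensions, the attention inner dimension now equals the model width and the context width, which the cross-attention RoPE test below relies on. A standalone arithmetic check (hypothetical snippet, not part of the test file):

heads, dim_head = 4, 128
inner_dim = heads * dim_head   # 512, matching the new self.D and self.context_dim
assert inner_dim == 512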
@@ -209,15 +247,27 @@ def test_shapes(self):
   def test_rope_frequency_parity(self):
     dim = 60
     rope_pt = PytorchLTX2RotaryPosEmbed(dim=dim)
-    rope_jax = LTX2RotaryPosEmbed(dim=dim)
+    rope_pt.double_precision = False
+    rope_jax = LTX2RotaryPosEmbed(dim=dim, double_precision=False)

     np_ids = np.random.randint(0, 100, (2, 16, 3)).astype(np.float32)
+
+    # 1. PyTorch Generation and BF16 Cast
     pt_cos, pt_sin = rope_pt(torch.from_numpy(np_ids))
+    pt_cos = pt_cos.to(torch.bfloat16)
+    pt_sin = pt_sin.to(torch.bfloat16)
+
+    # 2. JAX Generation and BF16 Cast
     jax_cos, jax_sin = rope_jax(jnp.array(np_ids))
+    jax_cos = jax_cos.astype(jnp.bfloat16)
+    jax_sin = jax_sin.astype(jnp.bfloat16)

-    np.testing.assert_allclose(pt_cos.numpy(), np.array(jax_cos), atol=1e-5)
-    np.testing.assert_allclose(pt_sin.numpy(), np.array(jax_sin), atol=1e-5)
-    print("[PASS] RoPE Frequency Parity Verified.")
+    # Note: Tolerance (5e-2) accounts for JAX XLA fast-math approximations
+    # combined with the bfloat16 truncation.
+    # We cast to float32 at the very end because NumPy testing doesn't natively support bfloat16.
+    np.testing.assert_allclose(pt_cos.float().numpy(), np.array(jax_cos, dtype=np.float32), rtol=0, atol=5e-2)
+    np.testing.assert_allclose(pt_sin.float().numpy(), np.array(jax_sin, dtype=np.float32), rtol=0, atol=5e-2)
+    print("[PASS] RoPE Frequency Parity (BF16) Verified.")

   def test_parity_bf16_strict(self):
     pt_model, jax_model = self._init_and_sync_models(dtype=jnp.bfloat16)
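The bfloat16 comparison pattern above (upcast both sides to float32 before assert_allclose, since NumPy testing has no native bfloat16 support) could be factored into a small helper. A hypothetical sketch, not part of this commit:

import numpy as np

def assert_bf16_allclose(pt_tensor, jax_array, atol=5e-2):
  # Upcast both bf16 outputs to float32 before comparing, because
  # np.testing.assert_allclose cannot operate on bfloat16 directly.
  np.testing.assert_allclose(
      pt_tensor.float().numpy(),
      np.array(jax_array, dtype=np.float32),
      rtol=0,
      atol=atol,
  )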
@@ -292,7 +342,8 @@ def test_cross_attn_rope_integration(self):
     np_x = np.random.randn(self.B, S_Q, self.D).astype(np.float32)
     np_ctx = np.random.randn(self.B, S_KV, self.D).astype(np.float32)

-    rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=64)  # Gen [B, S, InnerDim]
+    inner_dim = self.heads * self.dim_head
+    rope_gen_pt = PytorchLTX2RotaryPosEmbed(dim=inner_dim)  # Gen [B, S, InnerDim]

     ids_q = torch.randint(0, 100, (self.B, S_Q, 1))
     ids_k = torch.randint(0, 100, (self.B, S_KV, 1))
@@ -314,7 +365,7 @@ def test_cross_attn_rope_integration(self):

     diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
     print(f"\n[Cross-Attn + RoPE] Max Diff: {diff:.6f}")
-    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-5)
+    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=5e-3)
     print("[PASS] Cross-Attention with RoPE Parity Verified.")

   def test_attention_mask_parity(self):
@@ -327,16 +378,6 @@ def test_attention_mask_parity(self):

     jax_model.attention_op.attention_kernel = "flash"
     jax_model.attention_op.mesh = mesh
-    jax_model.attention_op.flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=128,
-        block_kv_compute=128,
-        block_kv=128,
-        block_q_dkv=128,
-        block_kv_dkv=128,
-        block_kv_dkv_compute=128,
-        block_q_dq=128,
-        block_kv_dq=128,
-    )

     mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
     pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]
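The two mask formats are equivalent at the softmax: the additive -1e9 bias fed to the PyTorch side drives the dropped keys' attention weights to roughly zero, matching the binary keep-mask handed to the JAX flash kernel. A minimal NumPy illustration (not part of the test file):

import numpy as np

scores = np.array([2.0, 1.0, 0.5, 3.0], dtype=np.float32)   # attention logits for one query
keep = np.array([1.0, 1.0, 0.0, 1.0], dtype=np.float32)     # 1 = attend, 0 = mask out

masked = scores + (1.0 - keep) * -1e9                       # additive form used for the PyTorch path
weights = np.exp(masked - masked.max())
weights /= weights.sum()
print(weights)                                              # weight for the masked key is ~0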
@@ -350,7 +391,7 @@ def test_attention_mask_parity(self):

     diff = np.abs(pt_out.numpy() - np.array(jax_out)).max()
     print(f"\n[Mask Parity] Max Diff (Flash): {diff:.6f}")
-    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=1e-4)
+    np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), atol=5e-3)
     print("[PASS] Attention Mask Parity Verified.")