Skip to content

Commit 1fd5338

Browse files
committed
modifying test for testing rope=split mode
1 parent ba15412 commit 1fd5338

2 files changed

Lines changed: 4 additions & 45 deletions

File tree

src/maxdiffusion/tests/ltx2_parity_test.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def test_transformer_block_shapes(self):
8686
audio_attention_head_dim=128,
8787
audio_cross_attention_dim=cross_dim,
8888
activation_fn="gelu",
89-
qk_norm="rms_norm_across_heads",
90-
qk_norm="rms_norm_across_heads",
9189
mesh=self.mesh,
9290
)
9391

@@ -194,32 +192,6 @@ def test_transformer_3d_model_instantiation_and_forward(self):
194192
mesh=self.mesh,
195193
)
196194

197-
# Inputs
198-
# hidden_states: (B, F, H, W, C) or (B, L, C)?
199-
# Diffusers `forward` takes `hidden_states` usually as latents.
200-
# If it's 3D, it might expect (B, C, F, H, W) or (B, F, C, H, W)?
201-
# Checking `transformer_ltx2.py` `__call__` Line 680:
202-
# `hidden_states = self.proj_in(hidden_states)`
203-
# `proj_in` is nnx.Linear.
204-
# This implies `hidden_states` input is ALREADY flattened/sequenced or `proj_in` assumes channel-last inputs.
205-
# If `proj_in` is Linear, input must be compatible with matrix mult.
206-
# Usually Transformers expect (B, L, D) or (B, N, D).
207-
# But `prepare_video_coords` logic suggests it handles spatial awareness.
208-
# The PROMPT usually implies `latents` of shape (B, C, F, H, W).
209-
# BUT `nnx.Linear` (Dense) applies to the last dimension.
210-
# If input is (B, C, F, H, W), Linear would act on W. That's wrong.
211-
# Diffusers LTX usually patchifies EXTERNALLY or has a conv patch embed?
212-
# In my definition (Line 491): `self.proj_in = nnx.Linear(...)`.
213-
# This differs from Conv3d.
214-
# This implies the user MUST pass flattened tokens?
215-
# Re-checking Diffusers implementation...
216-
# If `LTX2VideoTransformer3DModel` in Diffusers uses `patch_embed` (Conv), it takes 5D.
217-
# Verify `transformer_ltx2.py` user edits...
218-
# Step 426 (Original) had `nnx.Conv`.
219-
# Step 491 (New) has `nnx.Linear`.
220-
# This suggests input is EXPECTED to be flattened/patchified already OR raw channel-last (B, ..., C).
221-
# IMPORTANT: if `proj_in` is Linear, we pass (B, L, C).
222-
223195
# Let's pass (B, L, C).
224196
hidden_states = jnp.zeros((self.batch_size, self.seq_len, self.in_channels))
225197
audio_hidden_states = jnp.zeros((self.batch_size, 128, self.audio_in_channels))
@@ -352,15 +324,6 @@ def test_scan_remat_parity(self):
352324
model_loop = LTX2VideoTransformer3DModel(**args, scan_layers=False, mesh=self.mesh)
353325
model_remat = LTX2VideoTransformer3DModel(**args, scan_layers=True, remat_policy="full", mesh=self.mesh)
354326

355-
# 2. Sync weights (crucial for parity)
356-
# We can just copy params from scan to loop/remat
357-
# Assuming identical structure, nnx.state(model) should be compatible?
358-
# scan_layers=True uses `nnx.scan` which might change state structure (Scan variable?)
359-
# Actually maxdiffusion `transformer_wan.py` shows they are compatible if variable structure is clean.
360-
# But `nnx.scan` lifts variables into `Scan` collections sometimes?
361-
# Let's try simple state transfer or just basic shape check if transfer fails.
362-
# Ideally we want exact numerical parity.
363-
364327
# Inputs
365328
hidden_states = jnp.ones((self.batch_size, self.seq_len, self.in_channels)) * 0.5
366329
audio_hidden_states = jnp.ones((self.batch_size, 128, self.audio_in_channels)) * 0.5

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,10 @@ def test_ltx2_rope_split(self):
128128
cos, sin = rope(ids)
129129

130130
# Check output shape
131-
# Split RoPE returns concatenated [cos, cos] to match dim
132-
self.assertEqual(cos.shape, (1, 10, dim))
133-
self.assertEqual(sin.shape, (1, 10, dim))
134-
135-
# Verify values are concatenated
136-
cos1, cos2 = jnp.split(cos, 2, axis=-1)
137-
# They should be identical
138-
self.assertTrue(jnp.allclose(cos1, cos2))
131+
# Split RoPE returns [B, H, S, D//2]
132+
# dim=1024, heads=32 => head_dim=32 => D//2 = 16
133+
self.assertEqual(cos.shape, (1, 32, 10, 16))
134+
self.assertEqual(sin.shape, (1, 32, 10, 16))
139135

140136

141137
def test_ltx2_ada_layer_norm_single(self):

0 commit comments

Comments (0)