@@ -86,8 +86,6 @@ def test_transformer_block_shapes(self):
8686 audio_attention_head_dim = 128 ,
8787 audio_cross_attention_dim = cross_dim ,
8888 activation_fn = "gelu" ,
89- qk_norm = "rms_norm_across_heads" ,
90- qk_norm = "rms_norm_across_heads" ,
9189 mesh = self .mesh ,
9290 )
9391
@@ -194,32 +192,6 @@ def test_transformer_3d_model_instantiation_and_forward(self):
194192 mesh = self .mesh ,
195193 )
196194
197- # Inputs
198- # hidden_states: (B, F, H, W, C) or (B, L, C)?
199- # Diffusers `forward` takes `hidden_states` usually as latents.
200- # If it's 3D, it might expect (B, C, F, H, W) or (B, F, C, H, W)?
201- # Checking `transformer_ltx2.py` `__call__` Line 680:
202- # `hidden_states = self.proj_in(hidden_states)`
203- # `proj_in` is nnx.Linear.
204- # This implies `hidden_states` input is ALREADY flattened/sequenced or `proj_in` assumes channel-last inputs.
205- # If `proj_in` is Linear, input must be compatible with matrix mult.
206- # Usually Transformers expect (B, L, D) or (B, N, D).
207- # But `prepare_video_coords` logic suggests it handles spatial awareness.
208- # The PROMPT usually implies `latents` of shape (B, C, F, H, W).
209- # BUT `nnx.Linear` (Dense) applies to the last dimension.
210- # If input is (B, C, F, H, W), Linear would act on W. That's wrong.
211- # Diffusers LTX usually patchifies EXTERNALLY or has a conv patch embed?
212- # In my definition (Line 491): `self.proj_in = nnx.Linear(...)`.
213- # This differs from Conv3d.
214- # This implies the user MUST pass flattened tokens?
215- # Re-checking Diffusers implementation...
216- # If `LTX2VideoTransformer3DModel` in Diffusers uses `patch_embed` (Conv), it takes 5D.
217- # Verify `transformer_ltx2.py` user edits...
218- # Step 426 (Original) had `nnx.Conv`.
219- # Step 491 (New) has `nnx.Linear`.
220- # This suggests input is EXPECTED to be flattened/patchified already OR raw channel-last (B, ..., C).
221- # IMPORTANT: if `proj_in` is Linear, we pass (B, L, C).
222-
223195 # `proj_in` is an nnx.Linear (acts on the last dim), so pass already-flattened tokens of shape (B, L, C).
224196 hidden_states = jnp .zeros ((self .batch_size , self .seq_len , self .in_channels ))
225197 audio_hidden_states = jnp .zeros ((self .batch_size , 128 , self .audio_in_channels ))
@@ -352,15 +324,6 @@ def test_scan_remat_parity(self):
352324 model_loop = LTX2VideoTransformer3DModel (** args , scan_layers = False , mesh = self .mesh )
353325 model_remat = LTX2VideoTransformer3DModel (** args , scan_layers = True , remat_policy = "full" , mesh = self .mesh )
354326
355- # 2. Sync weights (crucial for parity)
356- # We can just copy params from scan to loop/remat
357- # Assuming identical structure, nnx.state(model) should be compatible?
358- # scan_layers=True uses `nnx.scan` which might change state structure (Scan variable?)
359- # Actually maxdiffusion `transformer_wan.py` shows they are compatible if variable structure is clean.
360- # But `nnx.scan` lifts variables into `Scan` collections sometimes?
361- # Let's try simple state transfer or just basic shape check if transfer fails.
362- # Ideally we want exact numerical parity.
363-
364327 # Inputs
365328 hidden_states = jnp .ones ((self .batch_size , self .seq_len , self .in_channels )) * 0.5
366329 audio_hidden_states = jnp .ones ((self .batch_size , 128 , self .audio_in_channels )) * 0.5
0 commit comments