Skip to content

Commit bf37ac2

Browse files
committed
passed mesh param to test for transformer
1 parent 755e1b2 commit bf37ac2

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def test_transformer_block_shapes(self):
8686
audio_cross_attention_dim=cross_dim,
8787
activation_fn="gelu",
8888
qk_norm="rms_norm_across_heads",
89+
mesh=self.mesh,
8990
)
9091

9192
# Create dummy inputs
@@ -188,7 +189,8 @@ def test_transformer_3d_model_instantiation_and_forward(self):
188189
audio_out_channels= self.audio_in_channels,
189190
audio_num_attention_heads=2,
190191
audio_attention_head_dim=16,
191-
audio_cross_attention_dim=32
192+
audio_cross_attention_dim=32,
193+
mesh=self.mesh,
192194
)
193195

194196
# Inputs
@@ -283,9 +285,9 @@ def test_scan_remat_parity(self):
283285

284286
# 1. Initialize models
285287
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
286-
model_scan = LTX2VideoTransformer3DModel(**args, scan_layers=True)
287-
model_loop = LTX2VideoTransformer3DModel(**args, scan_layers=False)
288-
model_remat = LTX2VideoTransformer3DModel(**args, scan_layers=True, remat_policy="full")
288+
model_scan = LTX2VideoTransformer3DModel(**args, scan_layers=True, mesh=self.mesh)
289+
model_loop = LTX2VideoTransformer3DModel(**args, scan_layers=False, mesh=self.mesh)
290+
model_remat = LTX2VideoTransformer3DModel(**args, scan_layers=True, remat_policy="full", mesh=self.mesh)
289291

290292
# 2. Sync weights (crucial for parity)
291293
# We can just copy params from scan to loop/remat

0 commit comments

Comments
 (0)