Skip to content

Commit a064fbe

Browse files
committed
fixes in ltx_transformer_step_test.py
1 parent 4cbc19c commit a064fbe

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxdiffusion/tests/ltx_transformer_step_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def test_one_step_transformer(self):
153153
state_shardings["transformer"] = transformer_state_shardings
154154
states["transformer"] = transformer_state
155155
example_inputs = {}
156-
batch_size, num_tokens = 4, 256
156+
batch_size, num_tokens = 8, 256
157157
input_shapes = {
158158
"latents": (batch_size, num_tokens, in_channels),
159159
"fractional_coords": (batch_size, 3, num_tokens),
@@ -194,6 +194,10 @@ def test_one_step_transformer(self):
194194

195195
noise_pred = p_run_inference(states).block_until_ready()
196196
noise_pred = torch.from_numpy(np.array(noise_pred))
197+
198+
# Using batch_size=8 to satisfy fsdp=8 sharding constraint, but reference is batch_size=4
199+
if batch_size != noise_pred_pt.shape[0]:
200+
noise_pred = noise_pred[:noise_pred_pt.shape[0]]
197201

198202
torch.testing.assert_close(noise_pred_pt, noise_pred, atol=0.025, rtol=20)
199203

0 commit comments

Comments (0)