Skip to content

Commit 9575ff5

Browse files
committed
feat: add LTX2 smoke test and update setup.sh environment
1 parent a03d34b commit 9575ff5

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

src/maxdiffusion/tests/ltx2_smoke_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import os
1818
import time
1919
import unittest
20-
import pytest
2120
import jax
2221
import jax.numpy as jnp
2322

@@ -32,6 +31,7 @@
3231
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
3332
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
3433

34+
3535
class LTX2SmokeTest(unittest.TestCase):
3636
"""End-to-end smoke test for LTX2."""
3737

@@ -43,9 +43,9 @@ def setUpClass(cls):
4343
None,
4444
os.path.join(THIS_DIR, "..", "configs", "ltx2_video.yml"),
4545
"num_inference_steps=2", # Small number of steps for fast test
46-
"height=256", # Small resolution
46+
"height=256", # Small resolution
4747
"width=256",
48-
"num_frames=9", # Small number of frames
48+
"num_frames=9", # Small number of frames
4949
"seed=0",
5050
"attention=flash",
5151
],
@@ -62,7 +62,7 @@ def setUpClass(cls):
6262
def test_ltx2_inference(self):
6363
"""Test that LTX2 pipeline can run inference and produce output."""
6464
generator = jax.random.key(self.config.seed)
65-
65+
6666
t0 = time.perf_counter()
6767
out = self.pipeline(
6868
prompt=self.prompt,
@@ -76,22 +76,22 @@ def test_ltx2_inference(self):
7676
dtype=jnp.bfloat16,
7777
)
7878
t1 = time.perf_counter()
79-
79+
8080
print(f"LTX2 Inference took: {t1 - t0:.2f}s")
81-
81+
8282
videos = out.frames if hasattr(out, "frames") else out[0]
8383
audios = out.audio if hasattr(out, "audio") else None
84-
84+
8585
self.assertIsNotNone(videos)
8686
# Check that we got frames
8787
self.assertGreater(len(videos), 0)
88-
88+
8989
# LTX2 might also produce audio, check if it's there if expected
9090
# The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio.
9191
# We can just log if audio is present.
9292
if audios is not None:
93-
print(f"Audio produced with shape: {audios[0].shape}")
94-
self.assertGreater(len(audios), 0)
93+
print(f"Audio produced with shape: {audios[0].shape}")
94+
self.assertGreater(len(audios), 0)
9595

9696
if __name__ == "__main__":
9797
unittest.main()

0 commit comments

Comments
 (0)