Skip to content

Commit a03d34b

Browse files
committed
feat: add LTX2 smoke test and update setup.sh environment
1 parent c98002f commit a03d34b

2 files changed

Lines changed: 106 additions & 2 deletions

File tree

setup.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
# Enable "exit immediately if any command fails" option
2323
set -e
2424
export DEBIAN_FRONTEND=noninteractive
25+
export PIP_INDEX_URL=https://pypi.org/simple
26+
export UV_INDEX_URL=https://pypi.org/simple
2527

2628
echo "Checking Python version..."
2729
# This command will fail if the Python version is less than 3.12
@@ -106,8 +108,13 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
106108
exit 1
107109
fi
108110

109-
# Set uv to use system python by default
110-
export UV_SYSTEM_PYTHON=1
111+
# Set uv to use system python if not in a virtual environment
112+
if python3 -c 'import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)'; then
113+
echo "Virtual environment detected. UV will use it."
114+
else
115+
echo "System Python detected. Setting UV_SYSTEM_PYTHON=1."
116+
export UV_SYSTEM_PYTHON=1
117+
fi
111118

112119
# Install dependencies from requirements.txt first
113120
python3 -m uv pip install -U --resolution=lowest \
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
import time
19+
import unittest
20+
import pytest
21+
import jax
22+
import jax.numpy as jnp
23+
24+
from maxdiffusion import pyconfig
25+
from maxdiffusion.checkpointing.ltx2_checkpointer import LTX2Checkpointer
26+
27+
try:
28+
jax.distributed.initialize()
29+
except Exception:
30+
pass
31+
32+
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
33+
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
34+
35+
class LTX2SmokeTest(unittest.TestCase):
36+
"""End-to-end smoke test for LTX2."""
37+
38+
@classmethod
39+
def setUpClass(cls):
40+
# Initialize config with the LTX2 video config file
41+
pyconfig.initialize(
42+
[
43+
None,
44+
os.path.join(THIS_DIR, "..", "configs", "ltx2_video.yml"),
45+
"num_inference_steps=2", # Small number of steps for fast test
46+
"height=256", # Small resolution
47+
"width=256",
48+
"num_frames=9", # Small number of frames
49+
"seed=0",
50+
"attention=flash",
51+
],
52+
unittest=True,
53+
)
54+
cls.config = pyconfig.config
55+
checkpoint_loader = LTX2Checkpointer(config=cls.config)
56+
# Load pipeline without upsampler for simplicity in smoke test
57+
cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False)
58+
59+
cls.prompt = [cls.config.prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
60+
cls.negative_prompt = [cls.config.negative_prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
61+
62+
def test_ltx2_inference(self):
63+
"""Test that LTX2 pipeline can run inference and produce output."""
64+
generator = jax.random.key(self.config.seed)
65+
66+
t0 = time.perf_counter()
67+
out = self.pipeline(
68+
prompt=self.prompt,
69+
negative_prompt=self.negative_prompt,
70+
height=self.config.height,
71+
width=self.config.width,
72+
num_frames=self.config.num_frames,
73+
num_inference_steps=self.config.num_inference_steps,
74+
guidance_scale=self.config.guidance_scale,
75+
generator=generator,
76+
dtype=jnp.bfloat16,
77+
)
78+
t1 = time.perf_counter()
79+
80+
print(f"LTX2 Inference took: {t1 - t0:.2f}s")
81+
82+
videos = out.frames if hasattr(out, "frames") else out[0]
83+
audios = out.audio if hasattr(out, "audio") else None
84+
85+
self.assertIsNotNone(videos)
86+
# Check that we got frames
87+
self.assertGreater(len(videos), 0)
88+
89+
# LTX2 might also produce audio, check if it's there if expected
90+
# The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio.
91+
# We can just log if audio is present.
92+
if audios is not None:
93+
print(f"Audio produced with shape: {audios[0].shape}")
94+
self.assertGreater(len(audios), 0)
95+
96+
if __name__ == "__main__":
97+
unittest.main()

0 commit comments

Comments
 (0)