Skip to content

Commit 3a89142

Browse files
committed
feat: add LTX2 smoke test with flash block sizes, enable FSDP, and fix pipeline state sharding
1 parent c98002f commit 3a89142

3 files changed

Lines changed: 126 additions & 16 deletions

File tree

setup.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
# Enable "exit immediately if any command fails" option
2323
set -e
2424
export DEBIAN_FRONTEND=noninteractive
25+
export PIP_INDEX_URL=https://pypi.org/simple
26+
export UV_INDEX_URL=https://pypi.org/simple
2527

2628
echo "Checking Python version..."
2729
# This command will fail if the Python version is less than 3.12
@@ -106,8 +108,13 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
106108
exit 1
107109
fi
108110

109-
# Set uv to use system python by default
110-
export UV_SYSTEM_PYTHON=1
111+
# uv should target the system interpreter only when no virtual environment is
# active. Inside a venv, sys.prefix differs from sys.base_prefix, so the probe
# exits 0 there and non-zero otherwise.
if ! python3 -c 'import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)'; then
  echo "System Python detected. Setting UV_SYSTEM_PYTHON=1."
  export UV_SYSTEM_PYTHON=1
else
  echo "Virtual environment detected. UV will use it."
fi
111118

112119
# Install dependencies from requirements.txt first
113120
python3 -m uv pip install -U --resolution=lowest \

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
170170
for path, val in flax.traverse_util.flatten_dict(params).items():
171171
if restored_checkpoint:
172172
path = path[:-1]
173-
sharding = logical_state_sharding[path].value
174-
state[path].value = device_put_replicated(val, sharding)
173+
sharding = logical_state_sharding[path].get_value()
174+
state[path].set_value(device_put_replicated(val, sharding))
175175
state = nnx.from_flat_state(state)
176176

177177
transformer = nnx.merge(graphdef, state, rest_of_state)
@@ -351,10 +351,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
351351
for path, val in flax.traverse_util.flatten_dict(params).items():
352352
sharding = logical_state_sharding.get(path)
353353
if sharding is not None:
354-
sharding = sharding.value
355-
state[path].value = device_put_replicated(val, sharding)
354+
sharding = sharding.get_value()
355+
state[path].set_value(device_put_replicated(val, sharding))
356356
else:
357-
state[path].value = jax.device_put(val)
357+
state[path].set_value(jax.device_put(val))
358358

359359
state = nnx.from_flat_state(state)
360360
connectors = nnx.merge(graphdef, state, rest_of_state)
@@ -393,16 +393,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
393393
for path, val in flax.traverse_util.flatten_dict(params).items():
394394
sharding = logical_state_sharding.get(path)
395395
if sharding is not None:
396-
sharding = sharding.value
396+
sharding = sharding.get_value()
397397
try:
398398
replicate_vae = config.replicate_vae
399399
except ValueError:
400400
replicate_vae = False
401401
if replicate_vae:
402402
sharding = NamedSharding(mesh, P())
403-
state[path].value = device_put_replicated(val, sharding)
403+
state[path].set_value(device_put_replicated(val, sharding))
404404
else:
405-
state[path].value = jax.device_put(val)
405+
state[path].set_value(jax.device_put(val))
406406

407407
state = nnx.from_flat_state(state)
408408
vae = nnx.merge(graphdef, state, rest_of_state)
@@ -441,16 +441,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
441441
for path, val in flax.traverse_util.flatten_dict(params).items():
442442
sharding = logical_state_sharding.get(path)
443443
if sharding is not None:
444-
sharding = sharding.value
444+
sharding = sharding.get_value()
445445
try:
446446
replicate_vae = config.replicate_vae
447447
except ValueError:
448448
replicate_vae = False
449449
if replicate_vae:
450450
sharding = NamedSharding(mesh, P())
451-
state[path].value = device_put_replicated(val, sharding)
451+
state[path].set_value(device_put_replicated(val, sharding))
452452
else:
453-
state[path].value = jax.device_put(val)
453+
state[path].set_value(jax.device_put(val))
454454

455455
state = nnx.from_flat_state(state)
456456
audio_vae = nnx.merge(graphdef, state, rest_of_state)
@@ -510,10 +510,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
510510
for path, val in flax.traverse_util.flatten_dict(params).items():
511511
sharding = logical_state_sharding.get(path)
512512
if sharding is not None:
513-
sharding = sharding.value
514-
state[path].value = device_put_replicated(val, sharding)
513+
sharding = sharding.get_value()
514+
state[path].set_value(device_put_replicated(val, sharding))
515515
else:
516-
state[path].value = jax.device_put(val)
516+
state[path].set_value(jax.device_put(val))
517517

518518
state = nnx.from_flat_state(state)
519519
vocoder = nnx.merge(graphdef, state, rest_of_state)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
import time
19+
import unittest
20+
import jax
21+
import jax.numpy as jnp
22+
23+
from maxdiffusion import pyconfig
24+
from maxdiffusion.checkpointing.ltx2_checkpointer import LTX2Checkpointer
25+
26+
# Best-effort start of the JAX distributed runtime. A failure must not abort
# the test module, but the reason is reported instead of being silently dropped.
# NOTE(review): presumably this raises on single-host runs with no coordinator
# configured, or when the runtime is already initialized -- confirm.
try:
  jax.distributed.initialize()
except Exception as err:  # deliberately broad: initialization is optional here
  print(f"jax.distributed.initialize() skipped: {err!r}")

# True only when the GITHUB_ACTIONS environment variable is exactly "true".
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
# Absolute directory containing this test file; used to locate the config.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
33+
34+
35+
class LTX2SmokeTest(unittest.TestCase):
36+
"""End-to-end smoke test for LTX2."""
37+
38+
@classmethod
39+
def setUpClass(cls):
40+
# Initialize config with the LTX2 video config file
41+
pyconfig.initialize(
42+
[
43+
None,
44+
os.path.join(THIS_DIR, "..", "configs", "ltx2_video.yml"),
45+
"num_inference_steps=2", # Small number of steps for fast test
46+
"height=256", # Small resolution
47+
"width=256",
48+
"num_frames=9", # Small number of frames
49+
"seed=0",
50+
"attention=flash",
51+
"ici_fsdp_parallelism=-1",
52+
"ici_data_parallelism=1",
53+
"ici_context_parallelism=1",
54+
"ici_tensor_parallelism=1",
55+
'flash_block_sizes={"block_q":512,"block_kv":512,"block_kv_compute":512,"block_q_dkv":512,"block_kv_dkv":512,"block_kv_dkv_compute":512,"use_fused_bwd_kernel":true}',
56+
],
57+
unittest=True,
58+
)
59+
cls.config = pyconfig.config
60+
checkpoint_loader = LTX2Checkpointer(config=cls.config)
61+
# Load pipeline without upsampler for simplicity in smoke test
62+
cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False)
63+
64+
cls.prompt = [cls.config.prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
65+
cls.negative_prompt = [cls.config.negative_prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
66+
67+
def test_ltx2_inference(self):
68+
"""Test that LTX2 pipeline can run inference and produce output."""
69+
generator = jax.random.key(self.config.seed)
70+
71+
t0 = time.perf_counter()
72+
out = self.pipeline(
73+
prompt=self.prompt,
74+
negative_prompt=self.negative_prompt,
75+
height=self.config.height,
76+
width=self.config.width,
77+
num_frames=self.config.num_frames,
78+
num_inference_steps=self.config.num_inference_steps,
79+
guidance_scale=self.config.guidance_scale,
80+
generator=generator,
81+
dtype=jnp.bfloat16,
82+
)
83+
t1 = time.perf_counter()
84+
85+
print(f"LTX2 Inference took: {t1 - t0:.2f}s")
86+
87+
videos = out.frames if hasattr(out, "frames") else out[0]
88+
audios = out.audio if hasattr(out, "audio") else None
89+
90+
self.assertIsNotNone(videos)
91+
# Check that we got frames
92+
self.assertGreater(len(videos), 0)
93+
94+
# LTX2 might also produce audio, check if it's there if expected
95+
# The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio.
96+
# We can just log if audio is present.
97+
if audios is not None:
98+
print(f"Audio produced with shape: {audios[0].shape}")
99+
self.assertGreater(len(audios), 0)
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main()

0 commit comments

Comments
 (0)