Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
# Enable "exit immediately if any command fails" option
set -e
export DEBIAN_FRONTEND=noninteractive
export PIP_INDEX_URL=https://pypi.org/simple
export UV_INDEX_URL=https://pypi.org/simple

echo "Checking Python version..."
# This command will fail if the Python version is less than 3.12
Expand Down Expand Up @@ -106,8 +108,13 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
exit 1
fi

# Set uv to use system python by default
export UV_SYSTEM_PYTHON=1
# Decide which Python uv should target: when no virtual environment is
# active, force uv onto the system interpreter; an active venv is honored
# as-is. The check compares sys.prefix to sys.base_prefix, which differ
# exactly when running inside a venv.
if ! python3 -c 'import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)'; then
  echo "System Python detected. Setting UV_SYSTEM_PYTHON=1."
  export UV_SYSTEM_PYTHON=1
else
  echo "Virtual environment detected. UV will use it."
fi

# Install dependencies from requirements.txt first
python3 -m uv pip install -U --resolution=lowest \
Expand Down
28 changes: 14 additions & 14 deletions src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
for path, val in flax.traverse_util.flatten_dict(params).items():
if restored_checkpoint:
path = path[:-1]
sharding = logical_state_sharding[path].value
state[path].value = device_put_replicated(val, sharding)
sharding = logical_state_sharding[path].get_value()
state[path].set_value(device_put_replicated(val, sharding))
state = nnx.from_flat_state(state)

transformer = nnx.merge(graphdef, state, rest_of_state)
Expand Down Expand Up @@ -351,10 +351,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding.get(path)
if sharding is not None:
sharding = sharding.value
state[path].value = device_put_replicated(val, sharding)
sharding = sharding.get_value()
state[path].set_value(device_put_replicated(val, sharding))
else:
state[path].value = jax.device_put(val)
state[path].set_value(jax.device_put(val))

state = nnx.from_flat_state(state)
connectors = nnx.merge(graphdef, state, rest_of_state)
Expand Down Expand Up @@ -393,16 +393,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding.get(path)
if sharding is not None:
sharding = sharding.value
sharding = sharding.get_value()
try:
replicate_vae = config.replicate_vae
except ValueError:
replicate_vae = False
if replicate_vae:
sharding = NamedSharding(mesh, P())
state[path].value = device_put_replicated(val, sharding)
state[path].set_value(device_put_replicated(val, sharding))
else:
state[path].value = jax.device_put(val)
state[path].set_value(jax.device_put(val))

state = nnx.from_flat_state(state)
vae = nnx.merge(graphdef, state, rest_of_state)
Expand Down Expand Up @@ -441,16 +441,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding.get(path)
if sharding is not None:
sharding = sharding.value
sharding = sharding.get_value()
try:
replicate_vae = config.replicate_vae
except ValueError:
replicate_vae = False
if replicate_vae:
sharding = NamedSharding(mesh, P())
state[path].value = device_put_replicated(val, sharding)
state[path].set_value(device_put_replicated(val, sharding))
else:
state[path].value = jax.device_put(val)
state[path].set_value(jax.device_put(val))

state = nnx.from_flat_state(state)
audio_vae = nnx.merge(graphdef, state, rest_of_state)
Expand Down Expand Up @@ -510,10 +510,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding.get(path)
if sharding is not None:
sharding = sharding.value
state[path].value = device_put_replicated(val, sharding)
sharding = sharding.get_value()
state[path].set_value(device_put_replicated(val, sharding))
else:
state[path].value = jax.device_put(val)
state[path].set_value(jax.device_put(val))

state = nnx.from_flat_state(state)
vocoder = nnx.merge(graphdef, state, rest_of_state)
Expand Down
109 changes: 109 additions & 0 deletions src/maxdiffusion/tests/generate_ltx2_smoke_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import time
import unittest
import jax
import jax.numpy as jnp

from maxdiffusion import pyconfig
from maxdiffusion.checkpointing.ltx2_checkpointer import LTX2Checkpointer

# Best-effort multi-host setup: jax.distributed.initialize() raises when the
# distributed runtime is unavailable or already initialized (e.g. single-host
# local runs), which is safe to ignore for a smoke test.
try:
    jax.distributed.initialize()
except Exception:
    pass

# True when running under GitHub Actions CI (GITHUB_ACTIONS env var).
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
# Absolute directory of this test file; used to locate the config YAML.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))


class LTX2SmokeTest(unittest.TestCase):
    """End-to-end smoke test for LTX2.

    Loads the LTX2 pipeline once for the class, runs a tiny (2-step, 256x256,
    9-frame) inference pass, and sanity-checks that video frames (and audio,
    when produced) come back non-empty.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize a minimal LTX2 config and load the pipeline once."""
        # Initialize config with the LTX2 video config file.
        pyconfig.initialize(
            [
                None,
                os.path.join(THIS_DIR, "..", "configs", "ltx2_video.yml"),
                "num_inference_steps=2",  # Small number of steps for fast test
                "height=256",  # Small resolution
                "width=256",
                "num_frames=9",  # Small number of frames
                "seed=0",
                "attention=flash",
                "ici_fsdp_parallelism=1",
                "ici_data_parallelism=1",
                "ici_context_parallelism=1",
                "ici_tensor_parallelism=-1",
            ],
            unittest=True,
        )
        cls.config = pyconfig.config
        checkpoint_loader = LTX2Checkpointer(config=cls.config)
        # Load pipeline without upsampler for simplicity in smoke test.
        cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False)

        # Replicate the prompts across the global batch; fall back to a batch
        # of 1 when the config does not define the attribute.
        batch_size = getattr(cls.config, "global_batch_size_to_train_on", 1)
        cls.prompt = [cls.config.prompt] * batch_size
        cls.negative_prompt = [cls.config.negative_prompt] * batch_size

    def test_ltx2_inference(self):
        """Test that LTX2 pipeline can run inference and produce output."""
        generator = jax.random.key(self.config.seed)

        t0 = time.perf_counter()
        out = self.pipeline(
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            height=self.config.height,
            width=self.config.width,
            num_frames=self.config.num_frames,
            num_inference_steps=self.config.num_inference_steps,
            guidance_scale=self.config.guidance_scale,
            generator=generator,
            dtype=jnp.bfloat16,
        )
        t1 = time.perf_counter()

        print(f"LTX2 Inference took: {t1 - t0:.2f}s")

        # The pipeline may return a structured output (attributes) or a tuple.
        videos = out.frames if hasattr(out, "frames") else out[0]
        audios = out.audio if hasattr(out, "audio") else None

        self.assertIsNotNone(videos)
        # Check that we got frames.
        self.assertGreater(len(videos), 0)

        # LTX2 might also produce audio; the config does not say whether this
        # run is T2AV or plain T2V, so only assert on audio when it is present.
        if audios is not None:
            print(f"Audio produced with shape: {audios[0].shape}")
            self.assertGreater(len(audios), 0)

    @classmethod
    def tearDownClass(cls):
        """Release the pipeline and force a GC pass.

        Guarded so that a failure inside setUpClass (before cls.pipeline was
        assigned) does not raise AttributeError here and mask the real error.
        """
        import gc

        if hasattr(cls, "pipeline"):
            del cls.pipeline
        gc.collect()


if __name__ == "__main__":
unittest.main()
Loading