diff --git a/setup.sh b/setup.sh index 45c952741..3f1141888 100644 --- a/setup.sh +++ b/setup.sh @@ -22,6 +22,8 @@ # Enable "exit immediately if any command fails" option set -e export DEBIAN_FRONTEND=noninteractive +export PIP_INDEX_URL=https://pypi.org/simple +export UV_INDEX_URL=https://pypi.org/simple echo "Checking Python version..." # This command will fail if the Python version is less than 3.12 @@ -106,8 +108,13 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then exit 1 fi -# Set uv to use system python by default -export UV_SYSTEM_PYTHON=1 +# Set uv to use system python if not in a virtual environment +if python3 -c 'import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)'; then + echo "Virtual environment detected. UV will use it." +else + echo "System Python detected. Setting UV_SYSTEM_PYTHON=1." + export UV_SYSTEM_PYTHON=1 +fi # Install dependencies from requirements.txt first python3 -m uv pip install -U --resolution=lowest \ diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index a3ec9591c..9cc1c970e 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -170,8 +170,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict): for path, val in flax.traverse_util.flatten_dict(params).items(): if restored_checkpoint: path = path[:-1] - sharding = logical_state_sharding[path].value - state[path].value = device_put_replicated(val, sharding) + sharding = logical_state_sharding[path].get_value() + state[path].set_value(device_put_replicated(val, sharding)) state = nnx.from_flat_state(state) transformer = nnx.merge(graphdef, state, rest_of_state) @@ -351,10 +351,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding.get(path) if sharding is not None: - sharding = sharding.value - state[path].value = device_put_replicated(val, sharding) + sharding = sharding.get_value() + state[path].set_value(device_put_replicated(val, sharding)) else: - state[path].value = jax.device_put(val) + state[path].set_value(jax.device_put(val)) state = nnx.from_flat_state(state) connectors = nnx.merge(graphdef, state, rest_of_state) @@ -393,16 +393,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding.get(path) if sharding is not None: - sharding = sharding.value + sharding = sharding.get_value() try: replicate_vae = config.replicate_vae except ValueError: replicate_vae = False if replicate_vae: sharding = NamedSharding(mesh, P()) - state[path].value = device_put_replicated(val, sharding) + state[path].set_value(device_put_replicated(val, sharding)) else: - state[path].value = jax.device_put(val) + state[path].set_value(jax.device_put(val)) state = nnx.from_flat_state(state) vae = nnx.merge(graphdef, state, rest_of_state) @@ -441,16 +441,16 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding.get(path) if sharding is not None: - sharding = sharding.value + sharding = sharding.get_value() try: replicate_vae = config.replicate_vae except ValueError: replicate_vae = False if replicate_vae: sharding = NamedSharding(mesh, P()) - state[path].value = device_put_replicated(val, sharding) + state[path].set_value(device_put_replicated(val, sharding)) else: - state[path].value = jax.device_put(val) + state[path].set_value(jax.device_put(val)) state = nnx.from_flat_state(state) audio_vae = nnx.merge(graphdef, state, rest_of_state) @@ -510,10 +510,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding.get(path) if sharding is not None: - sharding = sharding.value - state[path].value = device_put_replicated(val, sharding) + sharding = sharding.get_value() + state[path].set_value(device_put_replicated(val, sharding)) else: - state[path].value = jax.device_put(val) + state[path].set_value(jax.device_put(val)) state = nnx.from_flat_state(state) vocoder = nnx.merge(graphdef, state, rest_of_state) diff --git a/src/maxdiffusion/tests/generate_ltx2_smoke_test.py b/src/maxdiffusion/tests/generate_ltx2_smoke_test.py new file mode 100644 index 000000000..6d0bd0f34 --- /dev/null +++ b/src/maxdiffusion/tests/generate_ltx2_smoke_test.py @@ -0,0 +1,109 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import time +import unittest +import jax +import jax.numpy as jnp + +from maxdiffusion import pyconfig +from maxdiffusion.checkpointing.ltx2_checkpointer import LTX2Checkpointer + +try: + jax.distributed.initialize() +except Exception: + pass + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class LTX2SmokeTest(unittest.TestCase): + """End-to-end smoke test for LTX2.""" + + @classmethod + def setUpClass(cls): + # Initialize config with the LTX2 video config file + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "ltx2_video.yml"), + "num_inference_steps=2", # Small number of steps for fast test + "height=256", # Small resolution + "width=256", + "num_frames=9", # Small number of frames + "seed=0", + "attention=flash", + "ici_fsdp_parallelism=1", + "ici_data_parallelism=1", + "ici_context_parallelism=1", + "ici_tensor_parallelism=-1", + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = LTX2Checkpointer(config=cls.config) + # Load pipeline without upsampler for simplicity in smoke test + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False) + + cls.prompt = [cls.config.prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1) + cls.negative_prompt = [cls.config.negative_prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1) + + def test_ltx2_inference(self): + """Test that LTX2 pipeline can run inference and produce output.""" + generator = jax.random.key(self.config.seed) + + t0 = time.perf_counter() + out = self.pipeline( + prompt=self.prompt, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale=self.config.guidance_scale, + generator=generator, + dtype=jnp.bfloat16, + ) + t1 = time.perf_counter() + + print(f"LTX2 Inference took: {t1 - t0:.2f}s") + + videos = out.frames if hasattr(out, "frames") else out[0] + audios = out.audio if hasattr(out, "audio") else None + + self.assertIsNotNone(videos) + # Check that we got frames + self.assertGreater(len(videos), 0) + + # LTX2 might also produce audio, check if it's there if expected + # The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio. + # We can just log if audio is present. + if audios is not None: + print(f"Audio produced with shape: {audios[0].shape}") + self.assertGreater(len(audios), 0) + + @classmethod + def tearDownClass(cls): + del cls.pipeline + import gc + + gc.collect() + + +if __name__ == "__main__": + unittest.main()