Skip to content

Commit fcb5074

Browse files
committed
backward compatibility for ltx2
1 parent 6d38ae2 commit fcb5074

3 files changed

Lines changed: 16 additions & 3 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ height: 512
3636
width: 768
3737
decode_timestep: 0.05
3838
decode_noise_scale: 0.025
39+
# Matches historical MaxDiffusion LTX-2 default when using generate_ltx2 (Diffusers uses 0.0).
40+
noise_scale: 1.0
3941
num_frames: 121
4042
quantization: "int8"
4143
seed: 10

src/maxdiffusion/generate_ltx2.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import subprocess
2121
from maxdiffusion.checkpointing.ltx2_checkpointer import LTX2Checkpointer
2222
from maxdiffusion import pyconfig, max_logging, max_utils
23+
from maxdiffusion.common_types import LTX2_3
2324
from absl import app
2425
from google.cloud import storage
2526
from google.api_core.exceptions import GoogleAPIError
@@ -85,6 +86,12 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
8586
generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
8687
guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0
8788

89+
_noise_missing = object()
90+
_noise = getattr(config, "noise_scale", _noise_missing)
91+
if _noise is _noise_missing or _noise is None:
92+
# Legacy LTX-2 default; LTX-2.3 aligns with Diffusers (0.0). YAML may override either.
93+
_noise = 0.0 if getattr(config, "model_name", "") == LTX2_3 else 1.0
94+
8895
out = pipeline(
8996
prompt=prompt,
9097
negative_prompt=negative_prompt,
@@ -106,7 +113,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
106113
modality_scale=getattr(config, "modality_scale", 1.0),
107114
audio_modality_scale=getattr(config, "audio_modality_scale", None),
108115
use_cross_timestep=getattr(config, "use_cross_timestep", None),
109-
noise_scale=getattr(config, "noise_scale", 0.0),
116+
noise_scale=_noise,
110117
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
111118
)
112119
return out

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,7 @@ def __call__(
12601260
audio_guidance_rescale: Optional[float] = None,
12611261
audio_stg_scale: Optional[float] = None,
12621262
audio_modality_scale: Optional[float] = None,
1263-
noise_scale: float = 0.0,
1263+
noise_scale: float = 1.0,
12641264
num_videos_per_prompt: Optional[int] = 1,
12651265
generator: Optional[jax.Array] = None,
12661266
latents: Optional[jax.Array] = None,
@@ -1299,7 +1299,11 @@ def __call__(
12991299

13001300
do_cfg = (guidance_scale > 1.0) or (audio_guidance_scale > 1.0)
13011301
do_stg_effective = (stg_scale > 0.0) or (audio_stg_scale > 0.0)
1302-
do_modality_effective = (modality_scale > 1.0) or (audio_modality_scale > 1.0)
1302+
# Modality-isolation fused stacks match Diffusers LTX-2.3; LTX-2.0 weights/config ignore extra modality rows.
1303+
model_is_ltx2_3 = getattr(self.config, "model_name", "") == "ltx2.3"
1304+
do_modality_effective = model_is_ltx2_3 and (
1305+
(modality_scale > 1.0) or (audio_modality_scale > 1.0)
1306+
)
13031307

13041308
# 2. Encode inputs (Text)
13051309
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(

0 commit comments

Comments
 (0)