2020import subprocess
2121from maxdiffusion .checkpointing .ltx2_checkpointer import LTX2Checkpointer
2222from maxdiffusion import pyconfig , max_logging , max_utils
23+ from maxdiffusion .common_types import LTX2_3
2324from absl import app
2425from google .cloud import storage
2526from google .api_core .exceptions import GoogleAPIError
@@ -85,6 +86,12 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
8586 generator = jax .random .key (config .seed ) if hasattr (config , "seed" ) else jax .random .key (0 )
8687 guidance_scale = config .guidance_scale if hasattr (config , "guidance_scale" ) else 3.0
8788
89+ _noise_missing = object ()
90+ _noise = getattr (config , "noise_scale" , _noise_missing )
91+ if _noise is _noise_missing or _noise is None :
92+ # Legacy LTX-2 default; LTX-2.3 aligns with Diffusers (0.0). YAML may override either.
93+ _noise = 0.0 if getattr (config , "model_name" , "" ) == LTX2_3 else 1.0
94+
8895 out = pipeline (
8996 prompt = prompt ,
9097 negative_prompt = negative_prompt ,
@@ -106,7 +113,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
106113 modality_scale = getattr (config , "modality_scale" , 1.0 ),
107114 audio_modality_scale = getattr (config , "audio_modality_scale" , None ),
108115 use_cross_timestep = getattr (config , "use_cross_timestep" , None ),
109- noise_scale = getattr ( config , "noise_scale" , 0.0 ) ,
116+ noise_scale = _noise ,
110117 dtype = jnp .bfloat16 if getattr (config , "activations_dtype" , "bfloat16" ) == "bfloat16" else jnp .float32 ,
111118 )
112119 return out
0 commit comments