Skip to content

Commit ad71ce6

Browse files
authored
Merge branch 'main' into ltx2_lora
2 parents 3d36de2 + 993a5a6 commit ad71ce6

9 files changed

Lines changed: 999 additions & 21 deletions

File tree

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,19 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
7979
return restored_checkpoint, step
8080

8181
def load_checkpoint(
82-
self, step=None, vae_only=False, load_transformer=True
82+
self, step=None, vae_only=False, load_transformer=True, load_upsampler=False
8383
) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
8484
restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
8585
opt_state = None
8686

8787
if restored_checkpoint:
8888
max_logging.log("Loading LTX2 pipeline from checkpoint")
89-
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
89+
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer, load_upsampler)
9090
if "opt_state" in restored_checkpoint.ltx2_state.keys():
9191
opt_state = restored_checkpoint.ltx2_state["opt_state"]
9292
else:
9393
max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
94-
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
94+
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer, load_upsampler)
9595

9696
return pipeline, opt_state, step
9797

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,12 @@ lora_config: {
123123
rank: [32]
124124
}
125125

126+
# LTX-2 Latent Upsampler
127+
run_latent_upsampler: False
128+
upsampler_model_path: "Lightricks/LTX-2"
129+
upsampler_spatial_patch_size: 1
130+
upsampler_temporal_patch_size: 1
131+
upsampler_adain_factor: 0.0
132+
upsampler_tone_map_compression_ratio: 0.0
133+
upsampler_rational_spatial_scale: 2.0
134+
upsampler_output_type: "pil"

src/maxdiffusion/generate_ltx2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def get_git_commit_hash():
8282

8383

8484
def call_pipeline(config, pipeline, prompt, negative_prompt):
85-
# Set default generation arguments
8685
generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
8786
guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0
8887

@@ -100,6 +99,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
10099
decode_noise_scale=getattr(config, "decode_noise_scale", None),
101100
max_sequence_length=getattr(config, "max_sequence_length", 1024),
102101
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
102+
output_type=getattr(config, "upsampler_output_type", "pil"),
103103
)
104104
return out
105105

@@ -115,9 +115,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
115115
else:
116116
max_logging.log("Could not retrieve Git commit hash.")
117117

118+
checkpoint_loader = LTX2Checkpointer(config=config)
118119
if pipeline is None:
119-
checkpoint_loader = LTX2Checkpointer(config=config)
120-
pipeline, _, _ = checkpoint_loader.load_checkpoint()
120+
# Use the config flag to determine if the upsampler should be loaded
121+
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
122+
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
121123

122124
# If LoRA is specified, inject layers and load weights.
123125
if (
@@ -161,6 +163,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
161163
)
162164

163165
out = call_pipeline(config, pipeline, prompt, negative_prompt)
166+
164167
# out should have .frames and .audio
165168
videos = out.frames if hasattr(out, "frames") else out[0]
166169
audios = out.audio if hasattr(out, "audio") else None
@@ -169,6 +172,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
169172
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
170173
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
171174
max_logging.log(f"model type: {getattr(config, 'model_type', 'T2V')}")
175+
if getattr(config, "run_latent_upsampler", False):
176+
max_logging.log(f"upsampler model path: {config.upsampler_model_path}")
172177
max_logging.log(f"hardware: {jax.devices()[0].platform}")
173178
max_logging.log(f"number of devices: {jax.device_count()}")
174179
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")

0 commit comments

Comments (0)