Commit 6a77237

Upsampler pipeline
1 parent 9697900 commit 6a77237

8 files changed

Lines changed: 764 additions & 17 deletions

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 8 additions & 4 deletions
@@ -79,19 +79,23 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
     return restored_checkpoint, step

   def load_checkpoint(
-      self, step=None, vae_only=False, load_transformer=True
+      self, step=None, vae_only=False, load_transformer=True, load_upsampler=False
   ) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
     restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
     opt_state = None

     if restored_checkpoint:
       max_logging.log("Loading LTX2 pipeline from checkpoint")
-      pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
+      pipeline = LTX2Pipeline.from_checkpoint(
+          self.config, restored_checkpoint, vae_only, load_transformer, load_upsampler
+      )
       if "opt_state" in restored_checkpoint.ltx2_state.keys():
         opt_state = restored_checkpoint.ltx2_state["opt_state"]
     else:
       max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
-      pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
+      pipeline = LTX2Pipeline.from_pretrained(
+          self.config, vae_only, load_transformer, load_upsampler
+      )

     return pipeline, opt_state, step

@@ -110,4 +114,4 @@ def config_to_json(model_or_config):

     # Save the checkpoint
     self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
-    max_logging.log(f"Checkpoint for step {train_step} saved.")
+    max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/compare.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+import torch
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+# 1. Alias the imports to prevent name collisions between the two models.
+from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel as PT_Upsampler
+from maxdiffusion.models.ltx2.latent_upsampler_ltx2 import LTX2LatentUpsamplerModel as JAX_Upsampler
+from maxdiffusion.models.ltx2.ltx2_utils import load_upsampler_weights
+
+def test_side_by_side():
+  # --- Set up PyTorch ---
+  print("Initializing PyTorch Model...")
+  # Load the real pretrained weights.
+  pt_model = PT_Upsampler.from_pretrained("Lightricks/LTX-2", subfolder="latent_upsampler")
+  pt_model.eval()
+
+  # --- Set up JAX ---
+  print("Initializing JAX Model...")
+  jax_model = JAX_Upsampler()
+
+  print("Loading JAX Weights from HuggingFace...")
+  # Use the conversion utility so both models hold the exact same weights.
+  flax_params = load_upsampler_weights(
+      pretrained_model_name_or_path="Lightricks/LTX-2",
+      eval_shapes=None,
+      device="cpu",  # Load onto CPU for the comparison.
+      subfolder="latent_upsampler",
+  )
+
+  # Debug helper: inspect parameter dtypes and shapes.
+  # for path, value in jax.tree_util.tree_flatten_with_path(flax_params)[0]:
+  #   print(f"{jax.tree_util.keystr(path)}: {value.dtype}, shape: {value.shape}")
+
+  # --- Generate identical dummy data ---
+  # Shape: Batch=1, Channels=128, Frames=8, Height=32, Width=32.
+  print("Generating identical random inputs...")
+  torch.manual_seed(42)
+  pt_input = torch.randn(1, 128, 8, 32, 32, dtype=torch.float32)
+
+  # Convert PyTorch NCDHW -> JAX NDHWC:
+  # (0, 2, 3, 4, 1) maps (B, C, F, H, W) -> (B, F, H, W, C).
+  jax_input_np = pt_input.permute(0, 2, 3, 4, 1).numpy()
+  jax_input = jnp.array(jax_input_np)
+
+  # --- Run forward passes ---
+  print("Running PyTorch pass...")
+  with torch.no_grad():
+    pt_output = pt_model(pt_input)
+
+  print("Running JAX pass...")
+  jax_output = jax_model.apply({"params": flax_params}, jax_input)
+
+  # --- Compare results ---
+  # Convert the JAX output back to the PyTorch layout: NDHWC -> NCDHW.
+  # (0, 4, 1, 2, 3) maps (B, F, H, W, C) -> (B, C, F, H, W).
+  jax_output_converted = torch.tensor(np.array(jax_output)).permute(0, 4, 1, 2, 3)
+
+  # Mean squared error and maximum absolute difference.
+  mse = torch.nn.functional.mse_loss(pt_output, jax_output_converted)
+  max_diff = (pt_output - jax_output_converted).abs().max()
+
+  print("\n" + "=" * 30)
+  print("      COMPARISON RESULTS      ")
+  print("=" * 30)
+  print(f"Mean Squared Error: {mse.item():.8f}")
+  print(f"Max Absolute Error: {max_diff.item():.8f}")
+
+  if max_diff.item() < 1e-3:
+    print("\n✅ SUCCESS: The outputs match within tolerance.")
+  else:
+    print("\n❌ FAILED: The models diverge. There is a bug in the math/weights.")
+
+if __name__ == "__main__":
+  test_side_by_side()
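
Note: the two permutations in compare.py are exact inverses, so the layout conversion itself cannot introduce error; any mismatch comes from the models. A standalone sanity check of that claim, using only NumPy:

import numpy as np

x = np.random.randn(1, 128, 8, 32, 32)  # NCDHW, as fed to the PyTorch model
ndhwc = x.transpose(0, 2, 3, 4, 1)  # (B, C, F, H, W) -> (B, F, H, W, C)
back = ndhwc.transpose(0, 4, 1, 2, 3)  # (B, F, H, W, C) -> (B, C, F, H, W)
assert np.array_equal(x, back)  # the round trip is exact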

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 10 additions & 1 deletion
@@ -9,7 +9,7 @@ names_which_can_be_saved: []
 names_which_can_be_offloaded: []
 remat_policy: "NONE"

-jax_cache_dir: ''
+jax_cache_dir: '/mnt/disks/mehdy-disk1/maxdiffusion_hf_cache'
 weights_dtype: 'bfloat16'
 activations_dtype: 'bfloat16'

@@ -92,3 +92,12 @@ jit_initializers: True
 enable_single_replica_ckpt_restoring: False
 seed: 0
 audio_format: "s16"
+
+# LTX-2 Latent Upsampler
+run_latent_upsampler: False
+upsampler_model_path: "Lightricks/LTX-2"
+upsampler_spatial_patch_size: 1
+upsampler_temporal_patch_size: 1
+upsampler_adain_factor: 0.0
+upsampler_tone_map_compression_ratio: 0.0
+upsampler_rational_spatial_scale: 2.0
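
Note: these keys are read on the Python side with getattr defaults (see generate_ltx2.py below), so configs that omit them keep working. An illustrative sketch of that guarded-read pattern; the actual consumers of each key are not all shown in this commit:

# Sketch only: mirror the getattr pattern used in generate_ltx2.py.
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
upsampler_model_path = getattr(config, "upsampler_model_path", "Lightricks/LTX-2")
adain_factor = getattr(config, "upsampler_adain_factor", 0.0)
spatial_scale = getattr(config, "upsampler_rational_spatial_scale", 2.0)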

src/maxdiffusion/generate_ltx2.py

Lines changed: 13 additions & 5 deletions
@@ -15,6 +15,7 @@
 from typing import Sequence
 import jax
 import jax.numpy as jnp
+import numpy as np
 import time
 import os
 import subprocess

@@ -81,9 +82,9 @@ def get_git_commit_hash():


 def call_pipeline(config, pipeline, prompt, negative_prompt):
-  # Set default generation arguments
   generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
   guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0
+  output_type = "pil"

   out = pipeline(
       prompt=prompt,

@@ -99,7 +100,9 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
       decode_noise_scale=getattr(config, "decode_noise_scale", None),
       max_sequence_length=getattr(config, "max_sequence_length", 1024),
       dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
+      output_type=output_type,
   )
+
   return out

@@ -114,9 +117,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
   else:
     max_logging.log("Could not retrieve Git commit hash.")

+  checkpoint_loader = LTX2Checkpointer(config=config)
   if pipeline is None:
-    checkpoint_loader = LTX2Checkpointer(config=config)
-    pipeline, _, _ = checkpoint_loader.load_checkpoint()
+    # Use the config flag to determine if the upsampler should be loaded
+    run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
+    pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)

   pipeline.enable_vae_slicing()
   pipeline.enable_vae_tiling()

@@ -133,8 +138,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
   max_logging.log(
       f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
   )
-
+
   out = call_pipeline(config, pipeline, prompt, negative_prompt)
+
   # out should have .frames and .audio
   videos = out.frames if hasattr(out, "frames") else out[0]
   audios = out.audio if hasattr(out, "audio") else None

@@ -143,6 +149,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
   max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
   max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
   max_logging.log(f"model type: {getattr(config, 'model_type', 'T2V')}")
+  if getattr(config, "run_latent_upsampler", False):
+    max_logging.log(f"upsampler model path: {config.upsampler_model_path}")
   max_logging.log(f"hardware: {jax.devices()[0].platform}")
   max_logging.log(f"number of devices: {jax.device_count()}")
   max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")

@@ -218,4 +226,4 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  app.run(main)
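
Note: taken together, the flag flows from the YAML config through the checkpointer into generation. A condensed sketch of the end-to-end path using only names introduced in the diffs above (config construction and prompt handling elided):

# Sketch: end-to-end flow of run_latent_upsampler through this commit.
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
checkpoint_loader = LTX2Checkpointer(config=config)
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
out = call_pipeline(config, pipeline, prompt, negative_prompt)
videos = out.frames if hasattr(out, "frames") else out[0]  # PIL frames (output_type="pil")
audios = out.audio if hasattr(out, "audio") else None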
