Skip to content

Commit 80e8923

Browse files
committed
adding parity test file
1 parent 42cf3e1 commit 80e8923

1 file changed

Lines changed: 153 additions & 0 deletions

File tree

parity_ltx2_maxdiffusion.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import os
2+
import sys
3+
4+
# Ensure we use the local maxdiffusion src directory so we pull from the repository
5+
sys.path.insert(0, os.path.abspath("src"))
6+
7+
# Must patch jax.random.normal before importing anything that uses it
8+
import jax
9+
import jax.numpy as jnp
10+
import numpy as np
11+
import torch
12+
13+
# Keep a handle on the real sampler so the replacement can delegate to it.
orig_normal = jax.random.normal

# Reference noise arrays; each stays None until load_noises() reads it from disk.
video_noise = None
audio_noise = None


def _read_noise_file(path):
    """Return the array stored at *path*, or None (after a warning) if it is absent."""
    if not os.path.exists(path):
        print(f"Warning: {path} not found")
        return None
    return np.load(path)


def load_noises(video_noise_path="video_noise.npy", audio_noise_path="audio_noise.npy"):
    """Load the recorded noise arrays into the module-level globals.

    Args:
        video_noise_path: ``.npy`` file holding the video noise. Defaults to
            ``video_noise.npy`` in the current working directory.
        audio_noise_path: ``.npy`` file holding the audio noise. Defaults to
            ``audio_noise.npy`` in the current working directory.

    A missing file only produces a warning on stdout; the corresponding
    global keeps whatever value it had before the call.
    """
    global video_noise, audio_noise

    loaded_video = _read_noise_file(video_noise_path)
    if loaded_video is not None:
        video_noise = loaded_video

    loaded_audio = _read_noise_file(audio_noise_path)
    if loaded_audio is not None:
        audio_noise = loaded_audio
30+
31+
def custom_normal(key, shape=(), dtype=None, **kwargs):
    """Stand-in for ``jax.random.normal`` that replays pre-recorded noise.

    A 5-D request is answered with ``video_noise`` and a 4-D request with
    ``audio_noise``, but only when that array has been loaded and its shape
    equals the requested shape. Every other request is delegated to the real
    sampler saved in ``orig_normal``.

    Args:
        key: PRNG key, forwarded untouched to the real sampler.
        shape: Requested output shape. Defaults to ``()`` so the replacement
            can be called without a shape, like the function it replaces.
        dtype: Optional output dtype. Forwarded to the real sampler only when
            given, so the sampler's own default applies otherwise.
        **kwargs: Any further keyword arguments for the real sampler.

    Returns:
        The recorded noise as a JAX array, or the real sampler's result.
    """
    shape = tuple(shape)
    candidates = {5: ("video", video_noise), 4: ("audio", audio_noise)}
    label, recorded = candidates.get(len(shape), (None, None))

    if recorded is not None:
        if tuple(recorded.shape) == shape:
            return jnp.array(recorded, dtype=dtype)
        # Handing back an array of a different shape than requested would
        # either crash downstream or silently broadcast, so sample instead.
        print(
            f"Warning: {label} noise shape {tuple(recorded.shape)} does not match "
            f"requested shape {shape}; sampling fresh noise"
        )

    if dtype is not None:
        kwargs["dtype"] = dtype
    return orig_normal(key, shape, **kwargs)


jax.random.normal = custom_normal
39+
40+
from maxdiffusion import pyconfig
41+
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
42+
import maxdiffusion.pipelines.ltx2.ltx2_pipeline as pipe_module
43+
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
44+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
45+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
46+
from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder
47+
import transformers
48+
49+
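# Also patch through the pipeline module's own `jax` reference. This is normally the same module object patched above, so it acts only as a safeguard.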
50+
pipe_module.jax.random.normal = custom_normal
51+
52+
def print_stat(name, t):
    """Print the min, max, mean and std of *t* on one line, labelled *name*.

    Args:
        name: Label placed in square brackets at the start of the line.
        t: Array-like value (NumPy array, JAX array, torch tensor, ...).

    The statistics are computed in float32. An empty input prints
    ``[name] empty`` instead of raising from the reductions.
    """
    if isinstance(t, torch.Tensor):
        # NumPy cannot read torch tensors that are bfloat16, require grad or
        # live on an accelerator, so normalise them before converting.
        t = t.detach().to(device="cpu", dtype=torch.float32).numpy()

    t_np = np.asarray(t, dtype=np.float32)
    if t_np.size == 0:
        print(f"[{name}] empty")
        return

    print(f"[{name}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}")
55+
56+
# Instrument the transformer forward pass so each call reports output statistics.
orig_transformer_forward_pass = pipe_module.transformer_forward_pass


def patched_transformer_forward_pass(*args, **kwargs):
    """Run the original forward pass, print stats for both outputs, return them.

    All arguments are forwarded unchanged. The original is expected to return
    exactly two values, which are reported as ``transformer_video`` and
    ``transformer_audio`` and handed back as a 2-tuple.
    """
    video_pred, audio_pred = orig_transformer_forward_pass(*args, **kwargs)
    labelled = (
        ("transformer_video", video_pred),
        ("transformer_audio", audio_pred),
    )
    for label, prediction in labelled:
        print_stat(label, prediction)
    return video_pred, audio_pred


pipe_module.transformer_forward_pass = patched_transformer_forward_pass
64+
65+
# Instrument Gemma's forward so the text-encoder output statistics are printed.
orig_gemma_call = transformers.Gemma3ForConditionalGeneration.forward


def patched_gemma_call(self, *args, **kwargs):
    """Call the original ``forward`` and print stats for its output.

    When the result exposes a non-empty ``hidden_states`` attribute, its last
    entry is reported; otherwise, if the result is a list or tuple, its first
    element is reported. Any other result is returned without printing.
    """
    result = orig_gemma_call(self, *args, **kwargs)
    hidden_states = getattr(result, "hidden_states", None)
    if hidden_states:
        print_stat("text_encoder", hidden_states[-1])
        return result
    if isinstance(result, (list, tuple)):
        print_stat("text_encoder", result[0])
    return result


transformers.Gemma3ForConditionalGeneration.forward = patched_gemma_call
75+
76+
# Patch Connectors
orig_connector_call = LTX2AudioVideoGemmaTextEncoder.__call__


def patched_connector_call(self, *args, **kwargs):
    """Call the original connector ``__call__`` and print stats for its outputs.

    Arguments are forwarded unchanged, positional or keyword, in the same way
    as the other instrumentation wrappers in this file, so the wrapper does
    not restrict how the pipeline may invoke the connector.

    The first two elements of the result are reported as ``connectors_video``
    and ``connectors_audio``; the result itself is returned untouched.
    """
    out = orig_connector_call(self, *args, **kwargs)
    print_stat("connectors_video", out[0])
    print_stat("connectors_audio", out[1])
    return out


LTX2AudioVideoGemmaTextEncoder.__call__ = patched_connector_call
84+
85+
# Instrument the video VAE decoder so its output statistics are printed.
orig_vae_decode = LTX2VideoAutoencoderKL.decode


def patched_vae_decode(self, *args, **kwargs):
    """Call the original ``decode`` and print stats for what it produced.

    If the result is a tuple or list, its first element is reported;
    otherwise the result itself is. The result is returned unchanged.
    """
    decoded = orig_vae_decode(self, *args, **kwargs)
    reported = decoded[0] if isinstance(decoded, (tuple, list)) else decoded
    print_stat("vae_decoder", reported)
    return decoded


LTX2VideoAutoencoderKL.decode = patched_vae_decode
95+
96+
# Instrument the audio VAE decoder so its output statistics are printed.
orig_audio_vae_decode = FlaxAutoencoderKLLTX2Audio.decode


def patched_audio_vae_decode(self, *args, **kwargs):
    """Call the original audio ``decode`` and print stats for what it produced.

    If the result is a tuple or list, its first element is reported;
    otherwise the result itself is. The result is returned unchanged.
    """
    decoded = orig_audio_vae_decode(self, *args, **kwargs)
    reported = decoded[0] if isinstance(decoded, (tuple, list)) else decoded
    print_stat("audio_vae_decoder", reported)
    return decoded


FlaxAutoencoderKLLTX2Audio.decode = patched_audio_vae_decode
106+
107+
# Instrument the vocoder so its output statistics are printed.
orig_vocoder_call = LTX2Vocoder.__call__


def patched_vocoder_call(self, *args, **kwargs):
    """Call the original vocoder ``__call__``, print stats for the result, return it."""
    result = orig_vocoder_call(self, *args, **kwargs)
    print_stat("vocoder", result)
    return result


LTX2Vocoder.__call__ = patched_vocoder_call
114+
115+
116+
def _config_value(config, key, default):
    """Return ``config.<key>``, or *default* when the config does not provide it."""
    try:
        return getattr(config, key)
    except (AttributeError, ValueError):
        # getattr's own default only covers AttributeError.
        # NOTE(review): pyconfig's config object is assumed to raise
        # ValueError for unknown keys - confirm against pyconfig.
        return default


def main():
    """Load the recorded noise, build the LTX2 pipeline and run it once.

    Expected invocation::

        python parity_ltx2_maxdiffusion.py src/maxdiffusion/configs/ltx2_video.yml

    Generation settings are read from the initialised config, each with a
    fallback used when the config does not define the key. The pipeline's
    return value is not used; the script's output is the statistics printed
    by the instrumentation wrappers installed at import time.
    """
    load_noises()

    # sys.argv is handed to pyconfig as-is, so the config path must be present.
    if len(sys.argv) < 2:
        print("Please provide the path to ltx2_video.yml")
        sys.exit(1)

    pyconfig.initialize(sys.argv)
    config = pyconfig.config

    # Create the pipeline
    pipe = LTX2Pipeline.from_pretrained(config)

    prompt = _config_value(config, "prompt", "A man in a brightly lit room...")
    negative_prompt = _config_value(config, "negative_prompt", "shaky, glitchy, low quality...")
    height = _config_value(config, "height", 512)
    width = _config_value(config, "width", 768)
    num_frames = _config_value(config, "num_frames", 121)
    # NOTE(review): unlike the other settings this is not read from the
    # config - confirm that is intentional.
    frame_rate = 24.0

    print("Running MaxDiffusion pipeline...")
    pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        frame_rate=frame_rate,
        num_inference_steps=_config_value(config, "num_inference_steps", 40),
        guidance_scale=_config_value(config, "guidance_scale", 3.0),
        output_type="np",  # ensures VAE is called
        return_dict=False,
    )
    print("Done")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)