Skip to content

Commit 7ebe7fb

Browse files
committed
before_transformer parity file
1 parent df7e8dc commit 7ebe7fb

1 file changed

Lines changed: 154 additions & 0 deletions

File tree

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
import os
import sys

# Ensure we use the local maxdiffusion src directory so we pull from the
# repository checkout rather than any installed copy of the package.
sys.path.insert(0, os.path.abspath("src"))

import jax
import jax.numpy as jnp
import numpy as np
# --- Deterministic-noise patching -------------------------------------------
# Capture the original sampler before patching so the fallback path can still
# produce genuine random numbers.
orig_normal = jax.random.normal
# Pre-generated noise tensors loaded from parity files; None until
# load_noises() runs (or when the files are absent).
video_noise = None
audio_noise = None


def _load_optional(path):
    """Return the array stored at *path*, or None (with a warning) if missing."""
    if os.path.exists(path):
        return np.load(path)
    print(f"Warning: {path} not found")
    return None


def load_noises():
    """Populate the module-level video/audio noise arrays from parity files.

    A missing file leaves the corresponding global untouched, matching the
    original behavior (the globals start as None at import time).
    """
    global video_noise, audio_noise
    loaded = _load_optional("video_noise.npy")
    if loaded is not None:
        video_noise = loaded
    loaded = _load_optional("audio_noise.npy")
    if loaded is not None:
        audio_noise = loaded


def custom_normal(key, shape, dtype=None, **kwargs):
    """Drop-in replacement for jax.random.normal that returns pre-baked noise.

    A rank-5 request is treated as the video latent and a rank-4 request as
    the audio latent (assumed from the pipeline's sampling shapes — TODO
    confirm); anything else falls through to the real sampler.

    NOTE(review): the loaded array's shape is trusted to match *shape*; it is
    not validated here.
    """
    if len(shape) == 5 and video_noise is not None:
        return jnp.array(video_noise, dtype=dtype)
    if len(shape) == 4 and audio_noise is not None:
        return jnp.array(audio_noise, dtype=dtype)
    if dtype is None:
        # Bug fix: the original forwarded dtype=None, but jax.random.normal
        # requires a float dtype. Omit it so jax uses its default.
        return orig_normal(key, shape, **kwargs)
    return orig_normal(key, shape, dtype=dtype, **kwargs)


# Patch globally so any code resolving jax.random.normal lazily picks up the
# deterministic version.
jax.random.normal = custom_normal
from maxdiffusion import pyconfig
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
import maxdiffusion.pipelines.ltx2.ltx2_pipeline as pipe_module
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder

# Re-apply the patch through the pipeline module's own reference to jax, in
# case it bound jax.random.normal before the global patch above took effect.
pipe_module.jax.random.normal = custom_normal
def _print_stat_impl(name, t):
    """Print shape/min/max/mean/std of tensor-like *t* under label *name*.

    Torch tensors are detached and moved to CPU first (the 'cpu' attribute is
    used as the torch duck-type check); jax and numpy arrays are converted
    directly via np.array.
    """
    if hasattr(t, "cpu"):
        t = t.detach().cpu().float().numpy()
    t_np = np.array(t, dtype=np.float32)
    print(f"[{name}] shape: {t_np.shape}, min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}")

def print_stat(name, t):
    """Print stats for *t*, deferring via jax.debug.callback when traced."""
    if isinstance(t, jax.core.Tracer):
        # Inside jit/scan the values are not concrete yet; defer to runtime.
        jax.debug.callback(_print_stat_impl, name, t)
    else:
        _print_stat_impl(name, t)
# --- Patch the text-encoder connectors to log their outputs -----------------
orig_connector_call = LTX2AudioVideoGemmaTextEncoder.__call__

def patched_connector_call(self, hidden_states, attention_mask):
    """Wrap the encoder's __call__ to print stats for its two outputs.

    out[0]/out[1] are assumed to be the video and audio connector outputs —
    TODO confirm against LTX2AudioVideoGemmaTextEncoder.
    """
    out = orig_connector_call(self, hidden_states, attention_mask)
    print("\n=== CONNECTORS OUTPUTS ===")
    print_stat("connectors_video", out[0])
    print_stat("connectors_audio", out[1])
    return out

LTX2AudioVideoGemmaTextEncoder.__call__ = patched_connector_call
# --- Patch the transformer forward pass: log its inputs, then exit early ----
orig_transformer_forward_pass = pipe_module.transformer_forward_pass  # kept for reference; never invoked

def _arg_or_kwarg(args, kwargs, index, name):
    """Fetch a forward-pass input by keyword, falling back to position."""
    if name in kwargs:
        return kwargs[name]
    if len(args) > index:
        return args[index]
    return None

def patched_transformer_forward_pass(*args, **kwargs):
    """Print stats for each transformer input, then terminate the process.

    Positional layout assumed: (hidden_states, encoder_hidden_states,
    timestep, audio_hidden_states, audio_encoder_hidden_states, ...) — TODO
    confirm against pipe_module.transformer_forward_pass's signature.
    """
    print("\n=== TRANSFORMER INPUTS (MAXDIFFUSION) ===")
    # (label, positional index, keyword name) for every input we capture.
    probes = [
        ("transformer_input_video_latents", 0, "hidden_states"),
        ("transformers_encoder_hidden_states", 1, "encoder_hidden_states"),
        ("transformer_timestep", 2, "timestep"),
        ("transformer_input_audio_latents", 3, "audio_hidden_states"),
        ("transformers_audio_encoder_hidden_states", 4, "audio_encoder_hidden_states"),
    ]
    for label, index, name in probes:
        value = _arg_or_kwarg(args, kwargs, index, name)
        if value is not None:
            print_stat(label, value)

    print("\n[SUCCESS] Captured all inputs up to Transformer logic. Exiting early to save compute.\n")
    # Hard exit skips the (expensive) transformer entirely. os is already
    # imported at module level; the original's redundant local import is gone.
    os._exit(0)

pipe_module.transformer_forward_pass = patched_transformer_forward_pass
def _as_prompt_list(value):
    """Normalize a config value into the one-element list of strings that the
    pipeline's encode_prompt expects (works around its type-validation bug)."""
    if not isinstance(value, str):
        if isinstance(value, (list, tuple)):
            value = ", ".join(str(p) for p in value)
        else:
            value = str(value)
    return [value]


def main():
    """Run the LTX2 pipeline far enough to capture pre-transformer parity stats.

    Usage: python before_transformer_parity_maxdiffusion.py src/maxdiffusion/configs/ltx2_video.yml
    (The patched transformer forward pass exits the process before sampling.)
    """
    load_noises()

    if len(sys.argv) < 2:
        print("Please provide the path to ltx2_video.yml")
        sys.exit(1)

    pyconfig.initialize(sys.argv)
    config = pyconfig.config

    # Create the pipeline.
    pipe = LTX2Pipeline.from_pretrained(config)

    prompt = _as_prompt_list(getattr(config, "prompt", "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."))
    negative_prompt = _as_prompt_list(getattr(config, "negative_prompt", "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."))

    height = getattr(config, "height", 512)
    width = getattr(config, "width", 768)
    num_frames = getattr(config, "num_frames", 121)
    # Generalized: honor a configured frame rate, matching the getattr pattern
    # of the other parameters, instead of hard-coding 24.0.
    frame_rate = getattr(config, "frame_rate", 24.0)

    print("Running MaxDiffusion pipeline...")
    pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        frame_rate=frame_rate,
        num_inference_steps=getattr(config, "num_inference_steps", 40),
        guidance_scale=getattr(config, "guidance_scale", 3.0),
        output_type="np",  # ensures the VAE decode path is exercised
        return_dict=False,
    )
    print("Done")
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)