|
| 1 | +import os |
| 2 | +import sys |
| 3 | + |
| 4 | +# Ensure we use the local maxdiffusion src directory so we pull from the repository |
| 5 | +sys.path.insert(0, os.path.abspath("src")) |
| 6 | + |
| 7 | +import jax |
| 8 | +import jax.numpy as jnp |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +orig_normal = jax.random.normal |
| 12 | +video_noise = None |
| 13 | +audio_noise = None |
| 14 | + |
| 15 | +def load_noises(): |
| 16 | + global video_noise, audio_noise |
| 17 | + video_noise_path = "video_noise.npy" |
| 18 | + if os.path.exists(video_noise_path): |
| 19 | + video_noise = np.load(video_noise_path) |
| 20 | + else: |
| 21 | + print(f"Warning: {video_noise_path} not found") |
| 22 | + |
| 23 | + audio_noise_path = "audio_noise.npy" |
| 24 | + if os.path.exists(audio_noise_path): |
| 25 | + audio_noise = np.load(audio_noise_path) |
| 26 | + else: |
| 27 | + print(f"Warning: {audio_noise_path} not found") |
| 28 | + |
| 29 | +def custom_normal(key, shape, dtype=None, **kwargs): |
| 30 | + if len(shape) == 5 and video_noise is not None: |
| 31 | + return jnp.array(video_noise, dtype=dtype) |
| 32 | + if len(shape) == 4 and audio_noise is not None: |
| 33 | + return jnp.array(audio_noise, dtype=dtype) |
| 34 | + return orig_normal(key, shape, dtype=dtype, **kwargs) |
| 35 | + |
| 36 | +jax.random.normal = custom_normal |
| 37 | + |
| 38 | +from maxdiffusion import pyconfig |
| 39 | +from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline |
| 40 | +import maxdiffusion.pipelines.ltx2.ltx2_pipeline as pipe_module |
| 41 | +from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder |
| 42 | + |
| 43 | +pipe_module.jax.random.normal = custom_normal |
| 44 | + |
| 45 | +def _print_stat_impl(name, t): |
| 46 | + if hasattr(t, "cpu"): |
| 47 | + t = t.detach().cpu().float().numpy() |
| 48 | + t_np = np.array(t, dtype=np.float32) |
| 49 | + print(f"[{name}] shape: {t_np.shape}, min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}") |
| 50 | + |
| 51 | +def print_stat(name, t): |
| 52 | + if isinstance(t, jax.core.Tracer): |
| 53 | + jax.debug.callback(_print_stat_impl, name, t) |
| 54 | + else: |
| 55 | + _print_stat_impl(name, t) |
| 56 | + |
| 57 | +# Patch Connectors |
| 58 | +orig_connector_call = LTX2AudioVideoGemmaTextEncoder.__call__ |
| 59 | +def patched_connector_call(self, hidden_states, attention_mask): |
| 60 | + out = orig_connector_call(self, hidden_states, attention_mask) |
| 61 | + print("\n=== CONNECTORS OUTPUTS ===") |
| 62 | + print_stat("connectors_video", out[0]) |
| 63 | + print_stat("connectors_audio", out[1]) |
| 64 | + return out |
| 65 | +LTX2AudioVideoGemmaTextEncoder.__call__ = patched_connector_call |
| 66 | + |
| 67 | +# Patch Transformer forward pass to intercept inputs and EXIT EARLY |
| 68 | +orig_transformer_forward_pass = pipe_module.transformer_forward_pass |
| 69 | +def patched_transformer_forward_pass(*args, **kwargs): |
| 70 | + print("\n=== TRANSFORMER INPUTS (MAXDIFFUSION) ===") |
| 71 | + |
| 72 | + # In Maxdiffusion, args are usually (hidden_states, encoder_hidden_states, timestep, ...) |
| 73 | + if "hidden_states" in kwargs: |
| 74 | + print_stat("transformer_input_video_latents", kwargs["hidden_states"]) |
| 75 | + elif len(args) > 0 and args[0] is not None: |
| 76 | + print_stat("transformer_input_video_latents", args[0]) |
| 77 | + |
| 78 | + if "encoder_hidden_states" in kwargs: |
| 79 | + print_stat("transformers_encoder_hidden_states", kwargs["encoder_hidden_states"]) |
| 80 | + elif len(args) > 1 and args[1] is not None: |
| 81 | + print_stat("transformers_encoder_hidden_states", args[1]) |
| 82 | + |
| 83 | + if "timestep" in kwargs: |
| 84 | + print_stat("transformer_timestep", kwargs["timestep"]) |
| 85 | + elif len(args) > 2 and args[2] is not None: |
| 86 | + print_stat("transformer_timestep", args[2]) |
| 87 | + |
| 88 | + if "audio_hidden_states" in kwargs: |
| 89 | + print_stat("transformer_input_audio_latents", kwargs["audio_hidden_states"]) |
| 90 | + elif len(args) > 3 and args[3] is not None: |
| 91 | + print_stat("transformer_input_audio_latents", args[3]) |
| 92 | + |
| 93 | + if "audio_encoder_hidden_states" in kwargs: |
| 94 | + print_stat("transformers_audio_encoder_hidden_states", kwargs["audio_encoder_hidden_states"]) |
| 95 | + elif len(args) > 4 and args[4] is not None: |
| 96 | + print_stat("transformers_audio_encoder_hidden_states", args[4]) |
| 97 | + |
| 98 | + print("\n[SUCCESS] Captured all inputs up to Transformer logic. Exiting early to save compute.\n") |
| 99 | + import os |
| 100 | + os._exit(0) |
| 101 | +pipe_module.transformer_forward_pass = patched_transformer_forward_pass |
| 102 | + |
| 103 | +def main(): |
| 104 | + load_noises() |
| 105 | + |
| 106 | + # Init pyconfig, this assumes the user runs: python before_transformer_parity_maxdiffusion.py src/maxdiffusion/configs/ltx2_video.yml |
| 107 | + if len(sys.argv) < 2: |
| 108 | + print("Please provide the path to ltx2_video.yml") |
| 109 | + sys.exit(1) |
| 110 | + |
| 111 | + pyconfig.initialize(sys.argv) |
| 112 | + config = pyconfig.config |
| 113 | + |
| 114 | + # Create the pipeline |
| 115 | + pipe = LTX2Pipeline.from_pretrained(config) |
| 116 | + |
| 117 | + prompt = getattr(config, "prompt", "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie.") |
| 118 | + if not isinstance(prompt, str): |
| 119 | + if isinstance(prompt, (list, tuple)): |
| 120 | + prompt = ", ".join(str(p) for p in prompt) |
| 121 | + else: |
| 122 | + prompt = str(prompt) |
| 123 | + prompt = [prompt] # Pass as list to avoid pipeline encode_prompt type validation bug |
| 124 | + |
| 125 | + negative_prompt = getattr(config, "negative_prompt", "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.") |
| 126 | + if not isinstance(negative_prompt, str): |
| 127 | + if isinstance(negative_prompt, (list, tuple)): |
| 128 | + negative_prompt = ", ".join(str(p) for p in negative_prompt) |
| 129 | + else: |
| 130 | + negative_prompt = str(negative_prompt) |
| 131 | + negative_prompt = [negative_prompt] |
| 132 | + |
| 133 | + height = getattr(config, "height", 512) |
| 134 | + width = getattr(config, "width", 768) |
| 135 | + num_frames = getattr(config, "num_frames", 121) |
| 136 | + frame_rate = 24.0 |
| 137 | + |
| 138 | + print("Running MaxDiffusion pipeline...") |
| 139 | + out = pipe( |
| 140 | + prompt=prompt, |
| 141 | + negative_prompt=negative_prompt, |
| 142 | + height=height, |
| 143 | + width=width, |
| 144 | + num_frames=num_frames, |
| 145 | + frame_rate=frame_rate, |
| 146 | + num_inference_steps=getattr(config, "num_inference_steps", 40), |
| 147 | + guidance_scale=getattr(config, "guidance_scale", 3.0), |
| 148 | + output_type="np", # ensures VAE is called |
| 149 | + return_dict=False |
| 150 | + ) |
| 151 | + print("Done") |
| 152 | + |
| 153 | +if __name__ == '__main__': |
| 154 | + main() |
0 commit comments