Skip to content

Commit 049940e

Browse files
committed
fix cross attention
1 parent 7d7607a commit 049940e

2 files changed

Lines changed: 110 additions & 96 deletions

File tree

src/maxdiffusion/generate_wan_animate.py

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
11
# Copyright 2026 Google LLC
22
# Licensed under the Apache License, Version 2.0 (the "License");
33

4-
import jax
4+
"""Wan Animate inference entrypoint."""
5+
56
import os
67
import time
8+
79
from absl import app
8-
from maxdiffusion import pyconfig, max_logging, max_utils
9-
from maxdiffusion.train_utils import transformer_engine_context
10-
from maxdiffusion.utils import export_to_video
11-
from maxdiffusion.utils.loading_utils import load_image, load_video
1210
import flax
13-
from maxdiffusion.pipelines.wan.wan_pipeline_animate import WanAnimatePipeline
11+
import jax
1412
import numpy as np
1513
from PIL import Image
1614

15+
from maxdiffusion import max_logging, max_utils, pyconfig
16+
from maxdiffusion.pipelines.wan.wan_pipeline_animate import WanAnimatePipeline
17+
from maxdiffusion.train_utils import transformer_engine_context
18+
from maxdiffusion.utils import export_to_video
19+
from maxdiffusion.utils.loading_utils import load_image, load_video
20+
1721
jax.config.update("jax_use_shardy_partitioner", True)
1822

1923

2024
def _get_animate_inference_settings(config):
2125
"""Resolve animate-specific inference settings with upstream defaults."""
2226
return {
2327
"segment_frame_length": getattr(config, "segment_frame_length", 77),
24-
"prev_segment_conditioning_frames": getattr(config, "prev_segment_conditioning_frames", 1),
28+
"prev_segment_conditioning_frames": getattr(config, "prev_segment_conditioning_frames", 5),
2529
"motion_encode_batch_size": getattr(config, "motion_encode_batch_size", None),
2630
"guidance_scale": getattr(config, "animate_guidance_scale", 1.0),
2731
}
@@ -35,6 +39,7 @@ def _frame_summary(name, frames):
3539

3640

3741
def run(config):
42+
"""Run Wan Animate inference and write the generated videos to disk."""
3843
writer = max_utils.initialize_summary_writer(config)
3944
if jax.process_index() == 0 and writer:
4045
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
@@ -68,9 +73,7 @@ def run(config):
6873
motion_encoder_size = pipeline.transformer.config.motion_encoder_size
6974

7075
if pose_video_path and face_video_path:
71-
max_logging.log(
72-
f"Loading preprocessed videos from disk. pose_video={pose_video_path}, face_video={face_video_path}"
73-
)
76+
max_logging.log(f"Loading preprocessed videos from disk. pose_video={pose_video_path}, face_video={face_video_path}")
7477
pose_video = load_video(pose_video_path)
7578
face_video = load_video(face_video_path)
7679
num_frames = min(num_frames, len(pose_video), len(face_video))
@@ -85,7 +88,9 @@ def run(config):
8588
"For real outputs provide preprocessed pose_video_path and face_video_path."
8689
)
8790
pose_video = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8)) for _ in range(num_frames)]
88-
face_video = [Image.fromarray(np.zeros((motion_encoder_size, motion_encoder_size, 3), dtype=np.uint8)) for _ in range(num_frames)]
91+
face_video = [
92+
Image.fromarray(np.zeros((motion_encoder_size, motion_encoder_size, 3), dtype=np.uint8)) for _ in range(num_frames)
93+
]
8994

9095
background_video = None
9196
mask_video = None
@@ -96,47 +101,37 @@ def run(config):
96101
mask_video = load_video(mask_video_path)[:num_frames]
97102

98103
max_logging.log(
99-
"Wan animate inputs: reference_image=%s, image_size=%s, pose_video_path=%s, face_video_path=%s, %s, %s"
100-
% (
101-
reference_image_source,
102-
getattr(image, "size", None),
103-
pose_video_path or "<dummy>",
104-
face_video_path or "<dummy>",
105-
_frame_summary("pose", pose_video),
106-
_frame_summary("face", face_video),
107-
)
104+
"Wan animate inputs: "
105+
f"reference_image={reference_image_source}, "
106+
f"image_size={getattr(image, 'size', None)}, "
107+
f"pose_video_path={pose_video_path or '<dummy>'}, "
108+
f"face_video_path={face_video_path or '<dummy>'}, "
109+
f"{_frame_summary('pose', pose_video)}, "
110+
f"{_frame_summary('face', face_video)}"
108111
)
109112
if mode == "replace":
110113
max_logging.log(
111-
"Wan replace inputs: background_video_path=%s, mask_video_path=%s, %s, %s"
112-
% (
113-
background_video_path,
114-
mask_video_path,
115-
_frame_summary("background", background_video),
116-
_frame_summary("mask", mask_video),
117-
)
114+
"Wan replace inputs: "
115+
f"background_video_path={background_video_path}, "
116+
f"mask_video_path={mask_video_path}, "
117+
f"{_frame_summary('background', background_video)}, "
118+
f"{_frame_summary('mask', mask_video)}"
118119
)
119120

120121
animate_settings = _get_animate_inference_settings(config)
121122
prompt = config.prompt
122123
negative_prompt = config.negative_prompt if animate_settings["guidance_scale"] > 1.0 else None
123124

124125
max_logging.log(
125-
"Num steps: %s, height: %s, width: %s, frames: %s, segment_frame_length: %s, "
126-
"prev_segment_conditioning_frames: %s, guidance_scale: %s"
127-
% (
128-
config.num_inference_steps,
129-
height,
130-
width,
131-
num_frames,
132-
animate_settings["segment_frame_length"],
133-
animate_settings["prev_segment_conditioning_frames"],
134-
animate_settings["guidance_scale"],
135-
)
126+
"Num steps: "
127+
f"{config.num_inference_steps}, height: {height}, width: {width}, frames: {num_frames}, "
128+
f"segment_frame_length: {animate_settings['segment_frame_length']}, "
129+
f"prev_segment_conditioning_frames: {animate_settings['prev_segment_conditioning_frames']}, "
130+
f"guidance_scale: {animate_settings['guidance_scale']}"
136131
)
137132

138133
s0 = time.perf_counter()
139-
134+
140135
# First pass (compile)
141136
videos = pipeline(
142137
image=image,
@@ -155,7 +150,7 @@ def run(config):
155150
num_inference_steps=config.num_inference_steps,
156151
mode=mode,
157152
)
158-
153+
159154
compile_time = time.perf_counter() - s0
160155
max_logging.log(f"compile_time: {compile_time}")
161156
if writer and jax.process_index() == 0:
@@ -179,17 +174,17 @@ def run(config):
179174
num_inference_steps=config.num_inference_steps,
180175
mode=mode,
181176
)
182-
177+
183178
generation_time = time.perf_counter() - s0
184179
max_logging.log(f"generation_time: {generation_time}")
185180
if writer and jax.process_index() == 0:
186181
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
187182

188183
filename_prefix = "animate_"
189184
os.makedirs(config.output_dir, exist_ok=True)
190-
for i in range(len(videos)):
185+
for i, video in enumerate(videos):
191186
video_path = os.path.join(config.output_dir, f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
192-
export_to_video(videos[i], video_path, fps=config.fps)
187+
export_to_video(video, video_path, fps=config.fps)
193188
max_logging.log(f"Saved video to {video_path}")
194189

195190
if getattr(config, "enable_profiler", False):
@@ -220,6 +215,7 @@ def run(config):
220215

221216
return videos
222217

218+
223219
def main(argv) -> None:
224220
pyconfig.initialize(argv)
225221
try:
@@ -228,6 +224,7 @@ def main(argv) -> None:
228224
pass
229225
run(pyconfig.config)
230226

227+
231228
if __name__ == "__main__":
  # Entry point: wrap the whole app run in transformer_engine_context()
  # (from maxdiffusion.train_utils) so its setup/teardown brackets
  # flag parsing and inference; app.run parses absl flags then calls main.
  # NOTE(review): exact semantics of transformer_engine_context are defined
  # elsewhere in the package — confirm before relying on them here.
  with transformer_engine_context():
    app.run(main)

0 commit comments

Comments
 (0)