Skip to content

Commit 097f4c3

Browse files
committed
Refactor WAN animate pipeline and cache static computations
1 parent cde8ab8 commit 097f4c3

5 files changed

Lines changed: 668 additions & 424 deletions

File tree

src/maxdiffusion/generate_wan_animate.py

Lines changed: 109 additions & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -40,60 +40,65 @@ def run(config):
4040
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
4141

4242
load_start = time.perf_counter()
43-
pipeline = WanAnimatePipeline.from_pretrained(config)
43+
with jax.profiler.TraceAnnotation("wan_animate_load_pipeline"):
44+
pipeline = WanAnimatePipeline.from_pretrained(config)
4445
load_time = time.perf_counter() - load_start
4546
max_logging.log(f"load_time: {load_time:.1f}s")
4647

4748
# Setup inputs
48-
reference_image_path = getattr(config, "reference_image_path", "")
49-
if reference_image_path:
50-
image = load_image(reference_image_path)
51-
reference_image_source = reference_image_path
52-
else:
53-
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
54-
image = load_image(image_url)
55-
reference_image_source = image_url
56-
57-
mode = getattr(config, "mode", "animate")
58-
pose_video_path = getattr(config, "pose_video_path", "")
59-
face_video_path = getattr(config, "face_video_path", "")
60-
background_video_path = getattr(config, "background_video_path", "")
61-
mask_video_path = getattr(config, "mask_video_path", "")
62-
63-
num_frames = config.num_frames
64-
height = config.height
65-
width = config.width
66-
67-
# face_video needs to match motion_encoder_size (probably 224x224 or 256x256)
68-
motion_encoder_size = pipeline.transformer.config.motion_encoder_size
69-
70-
if pose_video_path and face_video_path:
71-
max_logging.log(
72-
f"Loading preprocessed videos from disk. pose_video={pose_video_path}, face_video={face_video_path}"
73-
)
74-
pose_video = load_video(pose_video_path)
75-
face_video = load_video(face_video_path)
76-
num_frames = min(num_frames, len(pose_video), len(face_video))
77-
if num_frames == 0:
78-
raise ValueError("Loaded empty pose/face video. Check preprocessing outputs.")
79-
pose_video = pose_video[:num_frames]
80-
face_video = face_video[:num_frames]
81-
else:
82-
# Fallback path used for quick smoke tests only.
83-
max_logging.log(
84-
"No pose/face video paths provided; generating dummy videos for a smoke test only. "
85-
"For real outputs provide preprocessed pose_video_path and face_video_path."
86-
)
87-
pose_video = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8)) for _ in range(num_frames)]
88-
face_video = [Image.fromarray(np.zeros((motion_encoder_size, motion_encoder_size, 3), dtype=np.uint8)) for _ in range(num_frames)]
89-
90-
background_video = None
91-
mask_video = None
92-
if mode == "replace":
93-
if not background_video_path or not mask_video_path:
94-
raise ValueError("Replace mode requires both `background_video_path` and `mask_video_path`.")
95-
background_video = load_video(background_video_path)[:num_frames]
96-
mask_video = load_video(mask_video_path)[:num_frames]
49+
with jax.profiler.TraceAnnotation("wan_animate_prepare_inputs"):
50+
reference_image_path = getattr(config, "reference_image_path", "")
51+
if reference_image_path:
52+
image = load_image(reference_image_path)
53+
reference_image_source = reference_image_path
54+
else:
55+
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
56+
image = load_image(image_url)
57+
reference_image_source = image_url
58+
59+
mode = getattr(config, "mode", "animate")
60+
pose_video_path = getattr(config, "pose_video_path", "")
61+
face_video_path = getattr(config, "face_video_path", "")
62+
background_video_path = getattr(config, "background_video_path", "")
63+
mask_video_path = getattr(config, "mask_video_path", "")
64+
65+
num_frames = config.num_frames
66+
height = config.height
67+
width = config.width
68+
69+
# face_video needs to match motion_encoder_size (probably 224x224 or 256x256)
70+
motion_encoder_size = pipeline.transformer.config.motion_encoder_size
71+
72+
if pose_video_path and face_video_path:
73+
max_logging.log(
74+
f"Loading preprocessed videos from disk. pose_video={pose_video_path}, face_video={face_video_path}"
75+
)
76+
pose_video = load_video(pose_video_path)
77+
face_video = load_video(face_video_path)
78+
num_frames = min(num_frames, len(pose_video), len(face_video))
79+
if num_frames == 0:
80+
raise ValueError("Loaded empty pose/face video. Check preprocessing outputs.")
81+
pose_video = pose_video[:num_frames]
82+
face_video = face_video[:num_frames]
83+
else:
84+
# Fallback path used for quick smoke tests only.
85+
max_logging.log(
86+
"No pose/face video paths provided; generating dummy videos for a smoke test only. "
87+
"For real outputs provide preprocessed pose_video_path and face_video_path."
88+
)
89+
pose_video = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8)) for _ in range(num_frames)]
90+
face_video = [
91+
Image.fromarray(np.zeros((motion_encoder_size, motion_encoder_size, 3), dtype=np.uint8))
92+
for _ in range(num_frames)
93+
]
94+
95+
background_video = None
96+
mask_video = None
97+
if mode == "replace":
98+
if not background_video_path or not mask_video_path:
99+
raise ValueError("Replace mode requires both `background_video_path` and `mask_video_path`.")
100+
background_video = load_video(background_video_path)[:num_frames]
101+
mask_video = load_video(mask_video_path)[:num_frames]
97102

98103
max_logging.log(
99104
"Wan animate inputs: reference_image=%s, image_size=%s, pose_video_path=%s, face_video_path=%s, %s, %s"
@@ -138,64 +143,33 @@ def run(config):
138143
s0 = time.perf_counter()
139144

140145
# First pass (compile)
141-
videos = pipeline(
142-
image=image,
143-
pose_video=pose_video,
144-
face_video=face_video,
145-
background_video=background_video,
146-
mask_video=mask_video,
147-
prompt=prompt,
148-
negative_prompt=negative_prompt,
149-
height=height,
150-
width=width,
151-
segment_frame_length=animate_settings["segment_frame_length"],
152-
prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"],
153-
motion_encode_batch_size=animate_settings["motion_encode_batch_size"],
154-
guidance_scale=animate_settings["guidance_scale"],
155-
num_inference_steps=config.num_inference_steps,
156-
mode=mode,
157-
)
146+
with jax.profiler.TraceAnnotation("wan_animate_compile_pass"):
147+
videos = pipeline(
148+
image=image,
149+
pose_video=pose_video,
150+
face_video=face_video,
151+
background_video=background_video,
152+
mask_video=mask_video,
153+
prompt=prompt,
154+
negative_prompt=negative_prompt,
155+
height=height,
156+
width=width,
157+
segment_frame_length=animate_settings["segment_frame_length"],
158+
prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"],
159+
motion_encode_batch_size=animate_settings["motion_encode_batch_size"],
160+
guidance_scale=animate_settings["guidance_scale"],
161+
num_inference_steps=config.num_inference_steps,
162+
mode=mode,
163+
)
158164

159165
compile_time = time.perf_counter() - s0
160166
max_logging.log(f"compile_time: {compile_time}")
161167
if writer and jax.process_index() == 0:
162168
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
163169

164170
s0 = time.perf_counter()
165-
videos = pipeline(
166-
image=image,
167-
pose_video=pose_video,
168-
face_video=face_video,
169-
background_video=background_video,
170-
mask_video=mask_video,
171-
prompt=prompt,
172-
negative_prompt=negative_prompt,
173-
height=height,
174-
width=width,
175-
segment_frame_length=animate_settings["segment_frame_length"],
176-
prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"],
177-
motion_encode_batch_size=animate_settings["motion_encode_batch_size"],
178-
guidance_scale=animate_settings["guidance_scale"],
179-
num_inference_steps=config.num_inference_steps,
180-
mode=mode,
181-
)
182-
183-
generation_time = time.perf_counter() - s0
184-
max_logging.log(f"generation_time: {generation_time}")
185-
if writer and jax.process_index() == 0:
186-
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
187-
188-
filename_prefix = "animate_"
189-
os.makedirs(config.output_dir, exist_ok=True)
190-
for i in range(len(videos)):
191-
video_path = os.path.join(config.output_dir, f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
192-
export_to_video(videos[i], video_path, fps=config.fps)
193-
max_logging.log(f"Saved video to {video_path}")
194-
195-
if getattr(config, "enable_profiler", False):
196-
s0 = time.perf_counter()
197-
max_utils.activate_profiler(config)
198-
_ = pipeline(
171+
with jax.profiler.TraceAnnotation("wan_animate_generation_pass"):
172+
videos = pipeline(
199173
image=image,
200174
pose_video=pose_video,
201175
face_video=face_video,
@@ -212,6 +186,40 @@ def run(config):
212186
num_inference_steps=config.num_inference_steps,
213187
mode=mode,
214188
)
189+
190+
generation_time = time.perf_counter() - s0
191+
max_logging.log(f"generation_time: {generation_time}")
192+
if writer and jax.process_index() == 0:
193+
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
194+
195+
filename_prefix = "animate_"
196+
os.makedirs(config.output_dir, exist_ok=True)
197+
for i in range(len(videos)):
198+
video_path = os.path.join(config.output_dir, f"{filename_prefix}wan_output_{config.seed}_{i}.mp4")
199+
export_to_video(videos[i], video_path, fps=config.fps)
200+
max_logging.log(f"Saved video to {video_path}")
201+
202+
if getattr(config, "enable_profiler", False):
203+
s0 = time.perf_counter()
204+
max_utils.activate_profiler(config)
205+
with jax.profiler.TraceAnnotation("wan_animate_profiled_pass"):
206+
_ = pipeline(
207+
image=image,
208+
pose_video=pose_video,
209+
face_video=face_video,
210+
background_video=background_video,
211+
mask_video=mask_video,
212+
prompt=prompt,
213+
negative_prompt=negative_prompt,
214+
height=height,
215+
width=width,
216+
segment_frame_length=animate_settings["segment_frame_length"],
217+
prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"],
218+
motion_encode_batch_size=animate_settings["motion_encode_batch_size"],
219+
guidance_scale=animate_settings["guidance_scale"],
220+
num_inference_steps=config.num_inference_steps,
221+
mode=mode,
222+
)
215223
max_utils.deactivate_profiler(config)
216224
generation_time_with_profiler = time.perf_counter() - s0
217225
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")

0 commit comments

Comments (0)