11# Copyright 2026 Google LLC
22# Licensed under the Apache License, Version 2.0 (the "License");
33
4- import jax
4+ """Wan Animate inference entrypoint."""
5+
56import os
67import time
8+
79from absl import app
8- from maxdiffusion import pyconfig , max_logging , max_utils
9- from maxdiffusion .train_utils import transformer_engine_context
10- from maxdiffusion .utils import export_to_video
11- from maxdiffusion .utils .loading_utils import load_image , load_video
1210import flax
13- from maxdiffusion . pipelines . wan . wan_pipeline_animate import WanAnimatePipeline
11+ import jax
1412import numpy as np
1513from PIL import Image
1614
15+ from maxdiffusion import max_logging , max_utils , pyconfig
16+ from maxdiffusion .pipelines .wan .wan_pipeline_animate import WanAnimatePipeline
17+ from maxdiffusion .train_utils import transformer_engine_context
18+ from maxdiffusion .utils import export_to_video
19+ from maxdiffusion .utils .loading_utils import load_image , load_video
20+
1721jax .config .update ("jax_use_shardy_partitioner" , True )
1822
1923
2024def _get_animate_inference_settings (config ):
2125 """Resolve animate-specific inference settings with upstream defaults."""
2226 return {
2327 "segment_frame_length" : getattr (config , "segment_frame_length" , 77 ),
24- "prev_segment_conditioning_frames" : getattr (config , "prev_segment_conditioning_frames" , 1 ),
28+ "prev_segment_conditioning_frames" : getattr (config , "prev_segment_conditioning_frames" , 5 ),
2529 "motion_encode_batch_size" : getattr (config , "motion_encode_batch_size" , None ),
2630 "guidance_scale" : getattr (config , "animate_guidance_scale" , 1.0 ),
2731 }
@@ -35,6 +39,7 @@ def _frame_summary(name, frames):
3539
3640
3741def run (config ):
42+ """Run Wan Animate inference and write the generated videos to disk."""
3843 writer = max_utils .initialize_summary_writer (config )
3944 if jax .process_index () == 0 and writer :
4045 max_logging .log (f"TensorBoard logs will be written to: { config .tensorboard_dir } " )
@@ -68,9 +73,7 @@ def run(config):
6873 motion_encoder_size = pipeline .transformer .config .motion_encoder_size
6974
7075 if pose_video_path and face_video_path :
71- max_logging .log (
72- f"Loading preprocessed videos from disk. pose_video={ pose_video_path } , face_video={ face_video_path } "
73- )
76+ max_logging .log (f"Loading preprocessed videos from disk. pose_video={ pose_video_path } , face_video={ face_video_path } " )
7477 pose_video = load_video (pose_video_path )
7578 face_video = load_video (face_video_path )
7679 num_frames = min (num_frames , len (pose_video ), len (face_video ))
@@ -85,7 +88,9 @@ def run(config):
8588 "For real outputs provide preprocessed pose_video_path and face_video_path."
8689 )
8790 pose_video = [Image .fromarray (np .zeros ((height , width , 3 ), dtype = np .uint8 )) for _ in range (num_frames )]
88- face_video = [Image .fromarray (np .zeros ((motion_encoder_size , motion_encoder_size , 3 ), dtype = np .uint8 )) for _ in range (num_frames )]
91+ face_video = [
92+ Image .fromarray (np .zeros ((motion_encoder_size , motion_encoder_size , 3 ), dtype = np .uint8 )) for _ in range (num_frames )
93+ ]
8994
9095 background_video = None
9196 mask_video = None
@@ -96,47 +101,37 @@ def run(config):
96101 mask_video = load_video (mask_video_path )[:num_frames ]
97102
98103 max_logging .log (
99- "Wan animate inputs: reference_image=%s, image_size=%s, pose_video_path=%s, face_video_path=%s, %s, %s"
100- % (
101- reference_image_source ,
102- getattr (image , "size" , None ),
103- pose_video_path or "<dummy>" ,
104- face_video_path or "<dummy>" ,
105- _frame_summary ("pose" , pose_video ),
106- _frame_summary ("face" , face_video ),
107- )
104+ "Wan animate inputs: "
105+ f"reference_image={ reference_image_source } , "
106+ f"image_size={ getattr (image , 'size' , None )} , "
107+ f"pose_video_path={ pose_video_path or '<dummy>' } , "
108+ f"face_video_path={ face_video_path or '<dummy>' } , "
109+ f"{ _frame_summary ('pose' , pose_video )} , "
110+ f"{ _frame_summary ('face' , face_video )} "
108111 )
109112 if mode == "replace" :
110113 max_logging .log (
111- "Wan replace inputs: background_video_path=%s, mask_video_path=%s, %s, %s"
112- % (
113- background_video_path ,
114- mask_video_path ,
115- _frame_summary ("background" , background_video ),
116- _frame_summary ("mask" , mask_video ),
117- )
114+ "Wan replace inputs: "
115+ f"background_video_path={ background_video_path } , "
116+ f"mask_video_path={ mask_video_path } , "
117+ f"{ _frame_summary ('background' , background_video )} , "
118+ f"{ _frame_summary ('mask' , mask_video )} "
118119 )
119120
120121 animate_settings = _get_animate_inference_settings (config )
121122 prompt = config .prompt
122123 negative_prompt = config .negative_prompt if animate_settings ["guidance_scale" ] > 1.0 else None
123124
124125 max_logging .log (
125- "Num steps: %s, height: %s, width: %s, frames: %s, segment_frame_length: %s, "
126- "prev_segment_conditioning_frames: %s, guidance_scale: %s"
127- % (
128- config .num_inference_steps ,
129- height ,
130- width ,
131- num_frames ,
132- animate_settings ["segment_frame_length" ],
133- animate_settings ["prev_segment_conditioning_frames" ],
134- animate_settings ["guidance_scale" ],
135- )
126+ "Num steps: "
127+ f"{ config .num_inference_steps } , height: { height } , width: { width } , frames: { num_frames } , "
128+ f"segment_frame_length: { animate_settings ['segment_frame_length' ]} , "
129+ f"prev_segment_conditioning_frames: { animate_settings ['prev_segment_conditioning_frames' ]} , "
130+ f"guidance_scale: { animate_settings ['guidance_scale' ]} "
136131 )
137132
138133 s0 = time .perf_counter ()
139-
134+
140135 # First pass (compile)
141136 videos = pipeline (
142137 image = image ,
@@ -155,7 +150,7 @@ def run(config):
155150 num_inference_steps = config .num_inference_steps ,
156151 mode = mode ,
157152 )
158-
153+
159154 compile_time = time .perf_counter () - s0
160155 max_logging .log (f"compile_time: { compile_time } " )
161156 if writer and jax .process_index () == 0 :
@@ -179,17 +174,17 @@ def run(config):
179174 num_inference_steps = config .num_inference_steps ,
180175 mode = mode ,
181176 )
182-
177+
183178 generation_time = time .perf_counter () - s0
184179 max_logging .log (f"generation_time: { generation_time } " )
185180 if writer and jax .process_index () == 0 :
186181 writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
187182
188183 filename_prefix = "animate_"
189184 os .makedirs (config .output_dir , exist_ok = True )
190- for i in range ( len ( videos ) ):
185+ for i , video in enumerate ( videos ):
191186 video_path = os .path .join (config .output_dir , f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4" )
192- export_to_video (videos [ i ] , video_path , fps = config .fps )
187+ export_to_video (video , video_path , fps = config .fps )
193188 max_logging .log (f"Saved video to { video_path } " )
194189
195190 if getattr (config , "enable_profiler" , False ):
@@ -220,6 +215,7 @@ def run(config):
220215
221216 return videos
222217
218+
223219def main (argv ) -> None :
224220 pyconfig .initialize (argv )
225221 try :
@@ -228,6 +224,7 @@ def main(argv) -> None:
228224 pass
229225 run (pyconfig .config )
230226
227+
if __name__ == "__main__":
  # Run the absl app inside the Transformer Engine context so it is active
  # for the entire inference run.
  with transformer_engine_context():
    app.run(main)
0 commit comments