@@ -40,60 +40,65 @@ def run(config):
4040 max_logging .log (f"TensorBoard logs will be written to: { config .tensorboard_dir } " )
4141
4242 load_start = time .perf_counter ()
43- pipeline = WanAnimatePipeline .from_pretrained (config )
43+ with jax .profiler .TraceAnnotation ("wan_animate_load_pipeline" ):
44+ pipeline = WanAnimatePipeline .from_pretrained (config )
4445 load_time = time .perf_counter () - load_start
4546 max_logging .log (f"load_time: { load_time :.1f} s" )
4647
4748 # Setup inputs
48- reference_image_path = getattr (config , "reference_image_path" , "" )
49- if reference_image_path :
50- image = load_image (reference_image_path )
51- reference_image_source = reference_image_path
52- else :
53- image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
54- image = load_image (image_url )
55- reference_image_source = image_url
56-
57- mode = getattr (config , "mode" , "animate" )
58- pose_video_path = getattr (config , "pose_video_path" , "" )
59- face_video_path = getattr (config , "face_video_path" , "" )
60- background_video_path = getattr (config , "background_video_path" , "" )
61- mask_video_path = getattr (config , "mask_video_path" , "" )
62-
63- num_frames = config .num_frames
64- height = config .height
65- width = config .width
66-
67- # face_video needs to match motion_encoder_size (probably 224x224 or 256x256)
68- motion_encoder_size = pipeline .transformer .config .motion_encoder_size
69-
70- if pose_video_path and face_video_path :
71- max_logging .log (
72- f"Loading preprocessed videos from disk. pose_video={ pose_video_path } , face_video={ face_video_path } "
73- )
74- pose_video = load_video (pose_video_path )
75- face_video = load_video (face_video_path )
76- num_frames = min (num_frames , len (pose_video ), len (face_video ))
77- if num_frames == 0 :
78- raise ValueError ("Loaded empty pose/face video. Check preprocessing outputs." )
79- pose_video = pose_video [:num_frames ]
80- face_video = face_video [:num_frames ]
81- else :
82- # Fallback path used for quick smoke tests only.
83- max_logging .log (
84- "No pose/face video paths provided; generating dummy videos for a smoke test only. "
85- "For real outputs provide preprocessed pose_video_path and face_video_path."
86- )
87- pose_video = [Image .fromarray (np .zeros ((height , width , 3 ), dtype = np .uint8 )) for _ in range (num_frames )]
88- face_video = [Image .fromarray (np .zeros ((motion_encoder_size , motion_encoder_size , 3 ), dtype = np .uint8 )) for _ in range (num_frames )]
89-
90- background_video = None
91- mask_video = None
92- if mode == "replace" :
93- if not background_video_path or not mask_video_path :
94- raise ValueError ("Replace mode requires both `background_video_path` and `mask_video_path`." )
95- background_video = load_video (background_video_path )[:num_frames ]
96- mask_video = load_video (mask_video_path )[:num_frames ]
49+ with jax .profiler .TraceAnnotation ("wan_animate_prepare_inputs" ):
50+ reference_image_path = getattr (config , "reference_image_path" , "" )
51+ if reference_image_path :
52+ image = load_image (reference_image_path )
53+ reference_image_source = reference_image_path
54+ else :
55+ image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
56+ image = load_image (image_url )
57+ reference_image_source = image_url
58+
59+ mode = getattr (config , "mode" , "animate" )
60+ pose_video_path = getattr (config , "pose_video_path" , "" )
61+ face_video_path = getattr (config , "face_video_path" , "" )
62+ background_video_path = getattr (config , "background_video_path" , "" )
63+ mask_video_path = getattr (config , "mask_video_path" , "" )
64+
65+ num_frames = config .num_frames
66+ height = config .height
67+ width = config .width
68+
69+ # face_video needs to match motion_encoder_size (probably 224x224 or 256x256)
70+ motion_encoder_size = pipeline .transformer .config .motion_encoder_size
71+
72+ if pose_video_path and face_video_path :
73+ max_logging .log (
74+ f"Loading preprocessed videos from disk. pose_video={ pose_video_path } , face_video={ face_video_path } "
75+ )
76+ pose_video = load_video (pose_video_path )
77+ face_video = load_video (face_video_path )
78+ num_frames = min (num_frames , len (pose_video ), len (face_video ))
79+ if num_frames == 0 :
80+ raise ValueError ("Loaded empty pose/face video. Check preprocessing outputs." )
81+ pose_video = pose_video [:num_frames ]
82+ face_video = face_video [:num_frames ]
83+ else :
84+ # Fallback path used for quick smoke tests only.
85+ max_logging .log (
86+ "No pose/face video paths provided; generating dummy videos for a smoke test only. "
87+ "For real outputs provide preprocessed pose_video_path and face_video_path."
88+ )
89+ pose_video = [Image .fromarray (np .zeros ((height , width , 3 ), dtype = np .uint8 )) for _ in range (num_frames )]
90+ face_video = [
91+ Image .fromarray (np .zeros ((motion_encoder_size , motion_encoder_size , 3 ), dtype = np .uint8 ))
92+ for _ in range (num_frames )
93+ ]
94+
95+ background_video = None
96+ mask_video = None
97+ if mode == "replace" :
98+ if not background_video_path or not mask_video_path :
99+ raise ValueError ("Replace mode requires both `background_video_path` and `mask_video_path`." )
100+ background_video = load_video (background_video_path )[:num_frames ]
101+ mask_video = load_video (mask_video_path )[:num_frames ]
97102
98103 max_logging .log (
99104 "Wan animate inputs: reference_image=%s, image_size=%s, pose_video_path=%s, face_video_path=%s, %s, %s"
@@ -138,64 +143,33 @@ def run(config):
138143 s0 = time .perf_counter ()
139144
140145 # First pass (compile)
141- videos = pipeline (
142- image = image ,
143- pose_video = pose_video ,
144- face_video = face_video ,
145- background_video = background_video ,
146- mask_video = mask_video ,
147- prompt = prompt ,
148- negative_prompt = negative_prompt ,
149- height = height ,
150- width = width ,
151- segment_frame_length = animate_settings ["segment_frame_length" ],
152- prev_segment_conditioning_frames = animate_settings ["prev_segment_conditioning_frames" ],
153- motion_encode_batch_size = animate_settings ["motion_encode_batch_size" ],
154- guidance_scale = animate_settings ["guidance_scale" ],
155- num_inference_steps = config .num_inference_steps ,
156- mode = mode ,
157- )
146+ with jax .profiler .TraceAnnotation ("wan_animate_compile_pass" ):
147+ videos = pipeline (
148+ image = image ,
149+ pose_video = pose_video ,
150+ face_video = face_video ,
151+ background_video = background_video ,
152+ mask_video = mask_video ,
153+ prompt = prompt ,
154+ negative_prompt = negative_prompt ,
155+ height = height ,
156+ width = width ,
157+ segment_frame_length = animate_settings ["segment_frame_length" ],
158+ prev_segment_conditioning_frames = animate_settings ["prev_segment_conditioning_frames" ],
159+ motion_encode_batch_size = animate_settings ["motion_encode_batch_size" ],
160+ guidance_scale = animate_settings ["guidance_scale" ],
161+ num_inference_steps = config .num_inference_steps ,
162+ mode = mode ,
163+ )
158164
159165 compile_time = time .perf_counter () - s0
160166 max_logging .log (f"compile_time: { compile_time } " )
161167 if writer and jax .process_index () == 0 :
162168 writer .add_scalar ("inference/compile_time" , compile_time , global_step = 0 )
163169
164170 s0 = time .perf_counter ()
165- videos = pipeline (
166- image = image ,
167- pose_video = pose_video ,
168- face_video = face_video ,
169- background_video = background_video ,
170- mask_video = mask_video ,
171- prompt = prompt ,
172- negative_prompt = negative_prompt ,
173- height = height ,
174- width = width ,
175- segment_frame_length = animate_settings ["segment_frame_length" ],
176- prev_segment_conditioning_frames = animate_settings ["prev_segment_conditioning_frames" ],
177- motion_encode_batch_size = animate_settings ["motion_encode_batch_size" ],
178- guidance_scale = animate_settings ["guidance_scale" ],
179- num_inference_steps = config .num_inference_steps ,
180- mode = mode ,
181- )
182-
183- generation_time = time .perf_counter () - s0
184- max_logging .log (f"generation_time: { generation_time } " )
185- if writer and jax .process_index () == 0 :
186- writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
187-
188- filename_prefix = "animate_"
189- os .makedirs (config .output_dir , exist_ok = True )
190- for i in range (len (videos )):
191- video_path = os .path .join (config .output_dir , f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4" )
192- export_to_video (videos [i ], video_path , fps = config .fps )
193- max_logging .log (f"Saved video to { video_path } " )
194-
195- if getattr (config , "enable_profiler" , False ):
196- s0 = time .perf_counter ()
197- max_utils .activate_profiler (config )
198- _ = pipeline (
171+ with jax .profiler .TraceAnnotation ("wan_animate_generation_pass" ):
172+ videos = pipeline (
199173 image = image ,
200174 pose_video = pose_video ,
201175 face_video = face_video ,
@@ -212,6 +186,40 @@ def run(config):
212186 num_inference_steps = config .num_inference_steps ,
213187 mode = mode ,
214188 )
189+
190+ generation_time = time .perf_counter () - s0
191+ max_logging .log (f"generation_time: { generation_time } " )
192+ if writer and jax .process_index () == 0 :
193+ writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
194+
195+ filename_prefix = "animate_"
196+ os .makedirs (config .output_dir , exist_ok = True )
197+ for i in range (len (videos )):
198+ video_path = os .path .join (config .output_dir , f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4" )
199+ export_to_video (videos [i ], video_path , fps = config .fps )
200+ max_logging .log (f"Saved video to { video_path } " )
201+
202+ if getattr (config , "enable_profiler" , False ):
203+ s0 = time .perf_counter ()
204+ max_utils .activate_profiler (config )
205+ with jax .profiler .TraceAnnotation ("wan_animate_profiled_pass" ):
206+ _ = pipeline (
207+ image = image ,
208+ pose_video = pose_video ,
209+ face_video = face_video ,
210+ background_video = background_video ,
211+ mask_video = mask_video ,
212+ prompt = prompt ,
213+ negative_prompt = negative_prompt ,
214+ height = height ,
215+ width = width ,
216+ segment_frame_length = animate_settings ["segment_frame_length" ],
217+ prev_segment_conditioning_frames = animate_settings ["prev_segment_conditioning_frames" ],
218+ motion_encode_batch_size = animate_settings ["motion_encode_batch_size" ],
219+ guidance_scale = animate_settings ["guidance_scale" ],
220+ num_inference_steps = config .num_inference_steps ,
221+ mode = mode ,
222+ )
215223 max_utils .deactivate_profiler (config )
216224 generation_time_with_profiler = time .perf_counter () - s0
217225 max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
0 commit comments