@@ -164,17 +164,11 @@ def __call__(
164164 max_logging .log (f"Adjusted num_frames to: { num_frames } " )
165165 num_frames = max (num_frames , 1 )
166166
167- prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size = self ._prepare_model_inputs_i2v (
168- prompt ,
169- image ,
170- negative_prompt ,
171- num_videos_per_prompt ,
172- max_sequence_length ,
173- prompt_embeds ,
174- negative_prompt_embeds ,
175- image_embeds ,
176- last_image ,
177- )
167+ # Compute the effective batch size early, so prepare_latents can be dispatched before text encoding and the two overlap.
168+ if prompt is not None and isinstance (prompt , str ):
169+ prompt = [prompt ]
170+ batch_size = len (prompt ) if prompt is not None else prompt_embeds .shape [0 ] // num_videos_per_prompt
171+ effective_batch_size = batch_size * num_videos_per_prompt
178172
179173 def _process_image_input (img_input , height , width , num_videos_per_prompt ):
180174 if img_input is None :
@@ -187,26 +181,44 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
187181 jax_array = jnp .repeat (jax_array , num_videos_per_prompt , axis = 0 )
188182 return jax_array
189183
184+ # 1. Dispatch the VAE path first (logical pipeline step 3): prepare latents so device work starts early.
190185 image_tensor = _process_image_input (image , height , width , effective_batch_size )
191186 last_image_tensor = _process_image_input (last_image , height , width , effective_batch_size )
192187
193188 if rng is None :
194189 rng = jax .random .key (self .config .seed )
195190 latents_rng , inference_rng = jax .random .split (rng )
196191
192+ # image_embeds has not been computed yet at this point, so fall back to config.activations_dtype for the latents.
193+ latents_dtype = self .config .activations_dtype
194+
197195 latents , condition , first_frame_mask = self .prepare_latents (
198196 image = image_tensor ,
199197 batch_size = effective_batch_size ,
200198 height = height ,
201199 width = width ,
202200 num_frames = num_frames ,
203- dtype = image_embeds . dtype ,
201+ dtype = latents_dtype ,
204202 rng = latents_rng ,
205203 latents = latents ,
206204 last_image = last_image_tensor ,
207205 num_videos_per_prompt = num_videos_per_prompt ,
208206 )
209207
208+ # 2. Dispatch the text and CLIP encoders (logical pipeline steps 1 & 2).
209+ # This may block on host-side text encoding, but the VAE work above is already dispatched, so the two overlap.
210+ prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size = self ._prepare_model_inputs_i2v (
211+ prompt ,
212+ image ,
213+ negative_prompt ,
214+ num_videos_per_prompt ,
215+ max_sequence_length ,
216+ prompt_embeds ,
217+ negative_prompt_embeds ,
218+ image_embeds ,
219+ last_image ,
220+ )
221+
210222 scheduler_state = self .scheduler .set_timesteps (
211223 self .scheduler_state , num_inference_steps = num_inference_steps , shape = latents .shape
212224 )
0 commit comments