@@ -176,10 +176,19 @@ def __call__(
176176 prompt_embeds , negative_prompt_embeds , image_embeds , last_image
177177 )
178178
179- image_tensor = self .video_processor .preprocess (image , height = height , width = width )
180- last_image_tensor = None
181- if last_image :
182- last_image_tensor = self .video_processor .preprocess (last_image , height = height , width = width )
179+ def _process_image_input (img_input , height , width , num_videos_per_prompt ):
180+ if img_input is None :
181+ return None
182+ tensor = self .video_processor .preprocess (img_input , height = height , width = width )
183+ jax_array = jnp .array (tensor .cpu ().numpy ())
184+ if jax_array .ndim == 3 :
185+ jax_array = jax_array [None , ...] # Add batch dimension
186+ if num_videos_per_prompt > 1 :
187+ jax_array = jnp .repeat (jax_array , num_videos_per_prompt , axis = 0 )
188+ return jax_array
189+
190+ image_tensor = _process_image_input (image , height , width , effective_batch_size )
191+ last_image_tensor = _process_image_input (last_image , height , width , effective_batch_size )
183192
184193 if rng is None :
185194 rng = jax .random .key (self .config .seed )
0 commit comments