@@ -164,22 +164,19 @@ def __call__(
164164 prompt , image , negative_prompt , num_videos_per_prompt , max_sequence_length ,
165165 prompt_embeds , negative_prompt_embeds , image_embeds , last_image
166166 )
167-
168- image_tensor = self .video_processor .preprocess (image , height = height , width = width )
169- if image_tensor .ndim == 3 :
170- image_tensor = image_tensor [None , ...]
171- last_image_tensor = None
172- if last_image :
173- last_image_tensor = self .video_processor .preprocess (last_image , height = height , width = width )
174- if last_image_tensor .ndim == 3 :
175- last_image_tensor = last_image_tensor [None , ...] # Add batch dimension
176-
177- if effective_batch_size > 1 :
178- image_tensor = jnp .repeat (image_tensor , effective_batch_size , axis = 0 )
179- if last_image_tensor is not None :
180- last_image_tensor = jnp .repeat (last_image_tensor , effective_batch_size , axis = 0 )
181-
182-
167+ def _process_image_input (img_input , height , width , num_videos_per_prompt ): # Nested helper (closes over `self`): turns one optional image input into a jnp array, or None when no input is given.
168+ if img_input is None : # Optional input not supplied: nothing to preprocess.
169+ return None
170+ tensor = self .video_processor .preprocess (img_input , height = height , width = width ) # Delegates to the pipeline's video processor with the requested height/width.
171+ jax_array = jnp .array (tensor .cpu ().numpy ()) # Goes through .cpu().numpy(); presumably preprocess returns a torch tensor -- TODO confirm.
172+ if jax_array .ndim == 3 :
173+ jax_array = jax_array [None , ...] # Add batch dimension
174+ if num_videos_per_prompt > 1 : # NOTE(review): both visible call sites pass effective_batch_size for this parameter, despite its name.
175+ jax_array = jnp .repeat (jax_array , num_videos_per_prompt , axis = 0 ) # Repeats entries along axis 0 so the leading dimension is multiplied by the repeat count.
176+ return jax_array
177+
178+ image_tensor = _process_image_input (image , height , width , effective_batch_size )
179+ last_image_tensor = _process_image_input (last_image , height , width , effective_batch_size )
183180
184181 if rng is None :
185182 rng = jax .random .key (self .config .seed )
0 commit comments