Skip to content

Commit d153740

Browse files
committed
Fix for multiple videos
1 parent 7c5726d commit d153740

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,19 @@ def __call__(
176176
prompt_embeds, negative_prompt_embeds, image_embeds, last_image
177177
)
178178

179-
image_tensor = self.video_processor.preprocess(image, height=height, width=width)
180-
last_image_tensor = None
181-
if last_image:
182-
last_image_tensor = self.video_processor.preprocess(last_image, height=height, width=width)
179+
def _process_image_input(img_input, height, width, num_videos_per_prompt):
180+
if img_input is None:
181+
return None
182+
tensor = self.video_processor.preprocess(img_input, height=height, width=width)
183+
jax_array = jnp.array(tensor.cpu().numpy())
184+
if jax_array.ndim == 3:
185+
jax_array = jax_array[None, ...] # Add batch dimension
186+
if num_videos_per_prompt > 1:
187+
jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0)
188+
return jax_array
189+
190+
image_tensor = _process_image_input(image, height, width, effective_batch_size)
191+
last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size)
183192

184193
if rng is None:
185194
rng = jax.random.key(self.config.seed)

0 commit comments

Comments
 (0)