Skip to content

Commit 7c5726d

Browse files
committed
Fix for multiple videos
1 parent 4580591 commit 7c5726d

1 file changed

Lines changed: 13 additions & 16 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,22 +164,19 @@ def __call__(
164164
prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length,
165165
prompt_embeds, negative_prompt_embeds, image_embeds, last_image
166166
)
167-
168-
image_tensor = self.video_processor.preprocess(image, height=height, width=width)
169-
if image_tensor.ndim == 3:
170-
image_tensor = image_tensor[None, ...]
171-
last_image_tensor = None
172-
if last_image:
173-
last_image_tensor = self.video_processor.preprocess(last_image, height=height, width=width)
174-
if last_image_tensor.ndim == 3:
175-
last_image_tensor = last_image_tensor[None, ...] # Add batch dimension
176-
177-
if effective_batch_size > 1:
178-
image_tensor = jnp.repeat(image_tensor, effective_batch_size, axis=0)
179-
if last_image_tensor is not None:
180-
last_image_tensor = jnp.repeat(last_image_tensor, effective_batch_size, axis=0)
181-
182-
167+
def _process_image_input(img_input, height, width, num_videos_per_prompt):
168+
if img_input is None:
169+
return None
170+
tensor = self.video_processor.preprocess(img_input, height=height, width=width)
171+
jax_array = jnp.array(tensor.cpu().numpy())
172+
if jax_array.ndim == 3:
173+
jax_array = jax_array[None, ...] # Add batch dimension
174+
if num_videos_per_prompt > 1:
175+
jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0)
176+
return jax_array
177+
178+
image_tensor = _process_image_input(image, height, width, effective_batch_size)
179+
last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size)
183180

184181
if rng is None:
185182
rng = jax.random.key(self.config.seed)

0 commit comments

Comments
 (0)