Skip to content

Commit 8e63fe0

Browse files
committed
Fix for multiple videos
1 parent 7543d00 commit 8e63fe0

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,20 @@ def __call__(
171171
)
172172

173173
image_tensor = self.video_processor.preprocess(image, height=height, width=width)
174+
if image_tensor.ndim == 3:
175+
image_tensor = image_tensor[None, ...]
174176
last_image_tensor = None
175177
if last_image:
176178
last_image_tensor = self.video_processor.preprocess(last_image, height=height, width=width)
179+
if last_image_tensor.ndim == 3:
180+
last_image_tensor = last_image_tensor[None, ...] # Add batch dimension
181+
182+
if effective_batch_size > 1:
183+
image_tensor = jnp.repeat(image_tensor, effective_batch_size, axis=0)
184+
if last_image_tensor is not None:
185+
last_image_tensor = jnp.repeat(last_image_tensor, effective_batch_size, axis=0)
186+
187+
177188

178189
if rng is None:
179190
rng = jax.random.key(self.config.seed)

0 commit comments

Comments
 (0)