File tree Expand file tree Collapse file tree
src/maxdiffusion/pipelines/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -171,9 +171,20 @@ def __call__(
171171 )
172172
173173 image_tensor = self .video_processor .preprocess (image , height = height , width = width )
174+ if image_tensor .ndim == 3 :
175+ image_tensor = image_tensor [None , ...]
174176 last_image_tensor = None
175177 if last_image :
176178 last_image_tensor = self .video_processor .preprocess (last_image , height = height , width = width )
179+ if last_image_tensor .ndim == 3 :
180+ last_image_tensor = last_image_tensor [None , ...] # Add batch dimension
181+
182+ if effective_batch_size > 1 :
183+ image_tensor = jnp .repeat (image_tensor , effective_batch_size , axis = 0 )
184+ if last_image_tensor is not None :
185+ last_image_tensor = jnp .repeat (last_image_tensor , effective_batch_size , axis = 0 )
186+
187+
177188
178189 if rng is None :
179190 rng = jax .random .key (self .config .seed )
You can’t perform that action at this time.
0 commit comments