Skip to content

Commit 13075ad

Browse files
committed
changes in prepare i2v model inputs
1 parent 6c49ce7 commit 13075ad

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def _prepare_model_inputs_i2v(
601601
):
602602
if prompt is not None and isinstance(prompt, str):
603603
prompt = [prompt]
604-
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
604+
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt
605605
effective_batch_size = batch_size * num_videos_per_prompt
606606

607607
# 1. Encode Prompts
@@ -632,11 +632,13 @@ def _prepare_model_inputs_i2v(
632632
if negative_prompt_embeds is not None:
633633
negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)
634634

635-
data_sharding = NamedSharding(self.mesh, P())
635+
prompt_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
636+
image_sharding = NamedSharding(self.mesh, P())
636637

637-
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
638-
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
639-
image_embeds = jax.device_put(image_embeds, data_sharding)
638+
639+
prompt_embeds = jax.device_put(prompt_embeds, prompt_sharding)
640+
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, prompt_sharding)
641+
image_embeds = jax.device_put(image_embeds, image_sharding)
640642

641643
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
642644

0 commit comments

Comments
 (0)