Skip to content

Commit 9035b2f

Browse files
committed
sharding for image and prompt embeds made same
1 parent 70164a8 commit 9035b2f

1 file changed

Lines changed: 4 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -632,13 +632,11 @@ def _prepare_model_inputs_i2v(
632632
if negative_prompt_embeds is not None:
633633
negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)
634634

635-
prompt_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
636-
image_sharding = NamedSharding(self.mesh, P())
635+
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
637636

638-
639-
prompt_embeds = jax.device_put(prompt_embeds, prompt_sharding)
640-
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, prompt_sharding)
641-
image_embeds = jax.device_put(image_embeds, image_sharding)
637+
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
638+
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
639+
image_embeds = jax.device_put(image_embeds, data_sharding)
642640

643641
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
644642

0 commit comments

Comments
 (0)