Skip to content

Commit 9eb359d

Browse files
committed
image_embeds sharded
1 parent 9e3e998 commit 9eb359d

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -637,14 +637,12 @@ def _prepare_model_inputs_i2v(
637637
if negative_prompt_embeds is not None:
638638
negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)
639639

640-
prompt_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
641-
image_sharding = NamedSharding(self.mesh, P())
642-
print(f"[DEBUG PREP] prompt_sharding spec: {self.config.data_sharding}")
643-
print(f"[DEBUG PREP] image_sharding spec: () - Replicated")
644-
645-
prompt_embeds = jax.device_put(prompt_embeds, prompt_sharding)
646-
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, prompt_sharding)
647-
image_embeds = jax.device_put(image_embeds, image_sharding)
640+
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
641+
print(f"[DEBUG PREP] data_sharding spec: {self.config.data_sharding}")
642+
643+
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
644+
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
645+
image_embeds = jax.device_put(image_embeds, data_sharding)
648646

649647
print(f"[DEBUG PREP] SHARDED prompt_embeds.shape: {prompt_embeds.shape}")
650648
print(f"[DEBUG PREP] SHARDED image_embeds.shape: {image_embeds.shape}")

0 commit comments

Comments
 (0)