Skip to content

Commit 9e3e998

Browse files
committed
image_embeds replicated
1 parent 3fb489a commit 9e3e998

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -637,12 +637,14 @@ def _prepare_model_inputs_i2v(
637637
if negative_prompt_embeds is not None:
638638
negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)
639639

640-
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
641-
print(f"[DEBUG PREP] data_sharding spec: {self.config.data_sharding}")
642-
643-
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
644-
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
645-
image_embeds = jax.device_put(image_embeds, data_sharding)
640+
prompt_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
641+
image_sharding = NamedSharding(self.mesh, P())
642+
print(f"[DEBUG PREP] prompt_sharding spec: {self.config.data_sharding}")
643+
print(f"[DEBUG PREP] image_sharding spec: () - Replicated")
644+
645+
prompt_embeds = jax.device_put(prompt_embeds, prompt_sharding)
646+
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, prompt_sharding)
647+
image_embeds = jax.device_put(image_embeds, image_sharding)
646648

647649
print(f"[DEBUG PREP] SHARDED prompt_embeds.shape: {prompt_embeds.shape}")
648650
print(f"[DEBUG PREP] SHARDED image_embeds.shape: {image_embeds.shape}")

0 commit comments

Comments
 (0)