File tree Expand file tree Collapse file tree
src/maxdiffusion/pipelines/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -637,12 +637,14 @@ def _prepare_model_inputs_i2v(
637637 if negative_prompt_embeds is not None :
638638 negative_prompt_embeds = negative_prompt_embeds .astype (transformer_dtype )
639639
640- data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
641- print (f"[DEBUG PREP] data_sharding spec: { self .config .data_sharding } " )
642-
643- prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
644- negative_prompt_embeds = jax .device_put (negative_prompt_embeds , data_sharding )
645- image_embeds = jax .device_put (image_embeds , data_sharding )
640+ prompt_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
641+ image_sharding = NamedSharding (self .mesh , P ())
642+ print (f"[DEBUG PREP] prompt_sharding spec: { self .config .data_sharding } " )
643+ print (f"[DEBUG PREP] image_sharding spec: () - Replicated" )
644+
645+ prompt_embeds = jax .device_put (prompt_embeds , prompt_sharding )
646+ negative_prompt_embeds = jax .device_put (negative_prompt_embeds , prompt_sharding )
647+ image_embeds = jax .device_put (image_embeds , image_sharding )
646648
647649 print (f"[DEBUG PREP] SHARDED prompt_embeds.shape: { prompt_embeds .shape } " )
648650 print (f"[DEBUG PREP] SHARDED image_embeds.shape: { image_embeds .shape } " )
You can’t perform that action at this time.
0 commit comments