File tree Expand file tree Collapse file tree
src/maxdiffusion/pipelines/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -637,14 +637,12 @@ def _prepare_model_inputs_i2v(
637637 if negative_prompt_embeds is not None :
638638 negative_prompt_embeds = negative_prompt_embeds .astype (transformer_dtype )
639639
640- prompt_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
641- image_sharding = NamedSharding (self .mesh , P ())
642- print (f"[DEBUG PREP] prompt_sharding spec: { self .config .data_sharding } " )
643- print (f"[DEBUG PREP] image_sharding spec: () - Replicated" )
644-
645- prompt_embeds = jax .device_put (prompt_embeds , prompt_sharding )
646- negative_prompt_embeds = jax .device_put (negative_prompt_embeds , prompt_sharding )
647- image_embeds = jax .device_put (image_embeds , image_sharding )
640+ data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
641+ print (f"[DEBUG PREP] data_sharding spec: { self .config .data_sharding } " )
642+
643+ prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
644+ negative_prompt_embeds = jax .device_put (negative_prompt_embeds , data_sharding )
645+ image_embeds = jax .device_put (image_embeds , data_sharding )
648646
649647 print (f"[DEBUG PREP] SHARDED prompt_embeds.shape: { prompt_embeds .shape } " )
650648 print (f"[DEBUG PREP] SHARDED image_embeds.shape: { image_embeds .shape } " )
You can’t perform that action at this time.
0 commit comments