@@ -601,7 +601,7 @@ def _prepare_model_inputs_i2v(
601601 ):
602602 if prompt is not None and isinstance (prompt , str ):
603603 prompt = [prompt ]
604- batch_size = len (prompt ) if prompt is not None else prompt_embeds .shape [0 ]
604+ batch_size = len (prompt ) if prompt is not None else prompt_embeds .shape [0 ] // num_videos_per_prompt
605605 effective_batch_size = batch_size * num_videos_per_prompt
606606
607607 # 1. Encode Prompts
@@ -632,11 +632,13 @@ def _prepare_model_inputs_i2v(
632632 if negative_prompt_embeds is not None :
633633 negative_prompt_embeds = negative_prompt_embeds .astype (transformer_dtype )
634634
635- data_sharding = NamedSharding (self .mesh , P ())
635+ prompt_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
636+ image_sharding = NamedSharding (self .mesh , P ())
636637
637- prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
638- negative_prompt_embeds = jax .device_put (negative_prompt_embeds , data_sharding )
639- image_embeds = jax .device_put (image_embeds , data_sharding )
638+
639+ prompt_embeds = jax .device_put (prompt_embeds , prompt_sharding )
640+ negative_prompt_embeds = jax .device_put (negative_prompt_embeds , prompt_sharding )
641+ image_embeds = jax .device_put (image_embeds , image_sharding )
640642
641643 return prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size
642644
0 commit comments