@@ -633,7 +633,7 @@ def _prepare_model_inputs_i2v(
633633 effective_batch_size = batch_size * num_videos_per_prompt
634634
635635 # 1. Encode Prompts
636- prompt_embeds , negative_prompt_embeds = self .encode_prompt (
636+ prompt_embeds , negative_prompt_embeds , text_attention_mask , negative_text_attention_mask = self .encode_prompt (
637637 prompt = prompt ,
638638 negative_prompt = negative_prompt ,
639639 num_videos_per_prompt = num_videos_per_prompt ,
@@ -677,9 +677,12 @@ def _prepare_model_inputs_i2v(
677677
678678 prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
679679 negative_prompt_embeds = jax .device_put (negative_prompt_embeds , data_sharding )
680+ text_attention_mask = jax .device_put (text_attention_mask , data_sharding )
681+ negative_text_attention_mask = jax .device_put (negative_text_attention_mask , data_sharding )
680682 image_embeds = jax .device_put (image_embeds , data_sharding )
681683
682- return prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size
684+ return (prompt_embeds , negative_prompt_embeds , text_attention_mask ,
685+ negative_text_attention_mask , image_embeds , effective_batch_size )
683686
684687
685688 def _prepare_model_inputs (
0 commit comments