@@ -832,14 +832,43 @@ def encode_prompt(
832832 batch_size = prompt_embeds .shape [0 ]
833833
834834 if prompt_embeds is None :
835- prompt_embeds , prompt_attention_mask = self ._get_gemma_prompt_embeds (
836- prompt = prompt ,
837- num_videos_per_prompt = num_videos_per_prompt ,
838- max_sequence_length = max_sequence_length ,
839- scale_factor = scale_factor ,
840- dtype = dtype ,
841- )
842-
835+ if do_classifier_free_guidance and negative_prompt_embeds is None :
836+ negative_prompt = negative_prompt or ""
837+ negative_prompt = [negative_prompt ] * batch_size if isinstance (negative_prompt , str ) else negative_prompt
838+
839+ if isinstance (prompt , str ):
840+ prompt = [prompt ]
841+
842+ combined_prompts = prompt + negative_prompt
843+
844+ combined_embeds , combined_mask = self ._get_gemma_prompt_embeds (
845+ prompt = combined_prompts ,
846+ num_videos_per_prompt = num_videos_per_prompt ,
847+ max_sequence_length = max_sequence_length ,
848+ scale_factor = scale_factor ,
849+ dtype = dtype ,
850+ )
851+
852+ split_idx = batch_size * num_videos_per_prompt
853+
854+ if isinstance (combined_embeds , list ):
855+ prompt_embeds = [state [:split_idx ] for state in combined_embeds ]
856+ negative_prompt_embeds = [state [split_idx :] for state in combined_embeds ]
857+ else :
858+ prompt_embeds = combined_embeds [:split_idx ]
859+ negative_prompt_embeds = combined_embeds [split_idx :]
860+
861+ prompt_attention_mask = combined_mask [:split_idx ]
862+ negative_prompt_attention_mask = combined_mask [split_idx :]
863+ else :
864+ prompt_embeds , prompt_attention_mask = self ._get_gemma_prompt_embeds (
865+ prompt = prompt ,
866+ num_videos_per_prompt = num_videos_per_prompt ,
867+ max_sequence_length = max_sequence_length ,
868+ scale_factor = scale_factor ,
869+ dtype = dtype ,
870+ )
871+
843872 if do_classifier_free_guidance and negative_prompt_embeds is None :
844873 negative_prompt = negative_prompt or ""
845874 negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
0 commit comments