2828from flax import nnx
2929from flax .linen import partitioning as nn_partitioning
3030from transformers import AutoTokenizer , GemmaTokenizer , GemmaTokenizerFast , Gemma3ForConditionalGeneration
31+ from maxdiffusion .tpu_utils import get_tpu_type , TpuType
3132import qwix
3233from ...utils import logging
3334from ...schedulers import FlaxFlowMatchScheduler
@@ -828,36 +829,41 @@ def encode_prompt(
828829 else :
829830 batch_size = prompt_embeds .shape [0 ]
830831
831- if prompt_embeds is None :
832- if do_classifier_free_guidance and negative_prompt_embeds is None :
833- negative_prompt = negative_prompt or ""
834- negative_prompt = [negative_prompt ] * batch_size if isinstance (negative_prompt , str ) else negative_prompt
832+ tpu_type = get_tpu_type ()
833+ # Batching the text encoder gives better results on Ironwood (v7x) but poorer results on Trillium (v6e)
834+ use_batched_text_encoder = tpu_type == TpuType .TPU_7X
835835
836- if isinstance (prompt , str ):
837- prompt = [prompt ]
836+ if use_batched_text_encoder and prompt_embeds is None and do_classifier_free_guidance and negative_prompt_embeds is None :
837+ negative_prompt = negative_prompt or ""
838+ negative_prompt = [negative_prompt ] * batch_size if isinstance (negative_prompt , str ) else negative_prompt
838839
839- combined_prompts = prompt + negative_prompt
840+ if isinstance (prompt , str ):
841+ prompt = [prompt ]
840842
841- combined_embeds , combined_mask = self ._get_gemma_prompt_embeds (
842- prompt = combined_prompts ,
843- num_videos_per_prompt = num_videos_per_prompt ,
844- max_sequence_length = max_sequence_length ,
845- scale_factor = scale_factor ,
846- dtype = dtype ,
847- )
843+ combined_prompts = prompt + negative_prompt
848844
849- split_idx = batch_size * num_videos_per_prompt
845+ combined_embeds , combined_mask = self ._get_gemma_prompt_embeds (
846+ prompt = combined_prompts ,
847+ num_videos_per_prompt = num_videos_per_prompt ,
848+ max_sequence_length = max_sequence_length ,
849+ scale_factor = scale_factor ,
850+ dtype = dtype ,
851+ )
850852
851- if isinstance (combined_embeds , list ):
852- prompt_embeds = [state [:split_idx ] for state in combined_embeds ]
853- negative_prompt_embeds = [state [split_idx :] for state in combined_embeds ]
854- else :
855- prompt_embeds = combined_embeds [:split_idx ]
856- negative_prompt_embeds = combined_embeds [split_idx :]
853+ split_idx = batch_size * num_videos_per_prompt
857854
858- prompt_attention_mask = combined_mask [:split_idx ]
859- negative_prompt_attention_mask = combined_mask [split_idx :]
855+ if isinstance (combined_embeds , list ):
856+ prompt_embeds = [state [:split_idx ] for state in combined_embeds ]
857+ negative_prompt_embeds = [state [split_idx :] for state in combined_embeds ]
860858 else :
859+ prompt_embeds = combined_embeds [:split_idx ]
860+ negative_prompt_embeds = combined_embeds [split_idx :]
861+
862+ prompt_attention_mask = combined_mask [:split_idx ]
863+ negative_prompt_attention_mask = combined_mask [split_idx :]
864+ else :
865+ # Non-batched path (Sequential)
866+ if prompt_embeds is None :
861867 prompt_embeds , prompt_attention_mask = self ._get_gemma_prompt_embeds (
862868 prompt = prompt ,
863869 num_videos_per_prompt = num_videos_per_prompt ,
@@ -866,22 +872,22 @@ def encode_prompt(
866872 dtype = dtype ,
867873 )
868874
869- if do_classifier_free_guidance and negative_prompt_embeds is None :
870- negative_prompt = negative_prompt or ""
871- negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
875+ if do_classifier_free_guidance and negative_prompt_embeds is None :
876+ negative_prompt = negative_prompt or ""
877+ negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
872878
873- if prompt is not None and type (prompt ) is not type (negative_prompt ):
874- raise TypeError (
875- f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !=" f" { type (prompt )} ."
876- )
879+ if prompt is not None and type (prompt ) is not type (negative_prompt ):
880+ raise TypeError (
881+ f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !=" f" { type (prompt )} ."
882+ )
877883
878- negative_prompt_embeds , negative_prompt_attention_mask = self ._get_gemma_prompt_embeds (
879- prompt = negative_prompt ,
880- num_videos_per_prompt = num_videos_per_prompt ,
881- max_sequence_length = max_sequence_length ,
882- scale_factor = scale_factor ,
883- dtype = dtype ,
884- )
884+ negative_prompt_embeds , negative_prompt_attention_mask = self ._get_gemma_prompt_embeds (
885+ prompt = negative_prompt ,
886+ num_videos_per_prompt = num_videos_per_prompt ,
887+ max_sequence_length = max_sequence_length ,
888+ scale_factor = scale_factor ,
889+ dtype = dtype ,
890+ )
885891
886892 return prompt_embeds , prompt_attention_mask , negative_prompt_embeds , negative_prompt_attention_mask
887893
0 commit comments