Skip to content

Commit c53cd1b

Browse files
committed
Enable text encoder batching only for TPU v7x, not v6e
1 parent 67cfc0f commit c53cd1b

1 file changed

Lines changed: 43 additions & 37 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from flax import nnx
2929
from flax.linen import partitioning as nn_partitioning
3030
from transformers import AutoTokenizer, GemmaTokenizer, GemmaTokenizerFast, Gemma3ForConditionalGeneration
31+
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
3132
import qwix
3233
from ...utils import logging
3334
from ...schedulers import FlaxFlowMatchScheduler
@@ -828,36 +829,41 @@ def encode_prompt(
828829
else:
829830
batch_size = prompt_embeds.shape[0]
830831

831-
if prompt_embeds is None:
832-
if do_classifier_free_guidance and negative_prompt_embeds is None:
833-
negative_prompt = negative_prompt or ""
834-
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
832+
tpu_type = get_tpu_type()
833+
# Batching the text encoder gives better results on Ironwood (v7x) but poor results on Trillium (v6e)
834+
use_batched_text_encoder = tpu_type == TpuType.TPU_7X
835835

836-
if isinstance(prompt, str):
837-
prompt = [prompt]
836+
if use_batched_text_encoder and prompt_embeds is None and do_classifier_free_guidance and negative_prompt_embeds is None:
837+
negative_prompt = negative_prompt or ""
838+
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
838839

839-
combined_prompts = prompt + negative_prompt
840+
if isinstance(prompt, str):
841+
prompt = [prompt]
840842

841-
combined_embeds, combined_mask = self._get_gemma_prompt_embeds(
842-
prompt=combined_prompts,
843-
num_videos_per_prompt=num_videos_per_prompt,
844-
max_sequence_length=max_sequence_length,
845-
scale_factor=scale_factor,
846-
dtype=dtype,
847-
)
843+
combined_prompts = prompt + negative_prompt
848844

849-
split_idx = batch_size * num_videos_per_prompt
845+
combined_embeds, combined_mask = self._get_gemma_prompt_embeds(
846+
prompt=combined_prompts,
847+
num_videos_per_prompt=num_videos_per_prompt,
848+
max_sequence_length=max_sequence_length,
849+
scale_factor=scale_factor,
850+
dtype=dtype,
851+
)
850852

851-
if isinstance(combined_embeds, list):
852-
prompt_embeds = [state[:split_idx] for state in combined_embeds]
853-
negative_prompt_embeds = [state[split_idx:] for state in combined_embeds]
854-
else:
855-
prompt_embeds = combined_embeds[:split_idx]
856-
negative_prompt_embeds = combined_embeds[split_idx:]
853+
split_idx = batch_size * num_videos_per_prompt
857854

858-
prompt_attention_mask = combined_mask[:split_idx]
859-
negative_prompt_attention_mask = combined_mask[split_idx:]
855+
if isinstance(combined_embeds, list):
856+
prompt_embeds = [state[:split_idx] for state in combined_embeds]
857+
negative_prompt_embeds = [state[split_idx:] for state in combined_embeds]
860858
else:
859+
prompt_embeds = combined_embeds[:split_idx]
860+
negative_prompt_embeds = combined_embeds[split_idx:]
861+
862+
prompt_attention_mask = combined_mask[:split_idx]
863+
negative_prompt_attention_mask = combined_mask[split_idx:]
864+
else:
865+
# Non-batched path (Sequential)
866+
if prompt_embeds is None:
861867
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
862868
prompt=prompt,
863869
num_videos_per_prompt=num_videos_per_prompt,
@@ -866,22 +872,22 @@ def encode_prompt(
866872
dtype=dtype,
867873
)
868874

869-
if do_classifier_free_guidance and negative_prompt_embeds is None:
870-
negative_prompt = negative_prompt or ""
871-
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
875+
if do_classifier_free_guidance and negative_prompt_embeds is None:
876+
negative_prompt = negative_prompt or ""
877+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
872878

873-
if prompt is not None and type(prompt) is not type(negative_prompt):
874-
raise TypeError(
875-
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}."
876-
)
879+
if prompt is not None and type(prompt) is not type(negative_prompt):
880+
raise TypeError(
881+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}."
882+
)
877883

878-
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
879-
prompt=negative_prompt,
880-
num_videos_per_prompt=num_videos_per_prompt,
881-
max_sequence_length=max_sequence_length,
882-
scale_factor=scale_factor,
883-
dtype=dtype,
884-
)
884+
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
885+
prompt=negative_prompt,
886+
num_videos_per_prompt=num_videos_per_prompt,
887+
max_sequence_length=max_sequence_length,
888+
scale_factor=scale_factor,
889+
dtype=dtype,
890+
)
885891

886892
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
887893

0 commit comments

Comments
 (0)