Skip to content

Commit 19a2dc6

Browse files
committed
text encoder batching
1 parent c3dd9fe commit 19a2dc6

1 file changed

Lines changed: 37 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -832,14 +832,43 @@ def encode_prompt(
832832
batch_size = prompt_embeds.shape[0]
833833

834834
if prompt_embeds is None:
835-
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
836-
prompt=prompt,
837-
num_videos_per_prompt=num_videos_per_prompt,
838-
max_sequence_length=max_sequence_length,
839-
scale_factor=scale_factor,
840-
dtype=dtype,
841-
)
842-
835+
if do_classifier_free_guidance and negative_prompt_embeds is None:
836+
negative_prompt = negative_prompt or ""
837+
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
838+
839+
if isinstance(prompt, str):
840+
prompt = [prompt]
841+
842+
combined_prompts = prompt + negative_prompt
843+
844+
combined_embeds, combined_mask = self._get_gemma_prompt_embeds(
845+
prompt=combined_prompts,
846+
num_videos_per_prompt=num_videos_per_prompt,
847+
max_sequence_length=max_sequence_length,
848+
scale_factor=scale_factor,
849+
dtype=dtype,
850+
)
851+
852+
split_idx = batch_size * num_videos_per_prompt
853+
854+
if isinstance(combined_embeds, list):
855+
prompt_embeds = [state[:split_idx] for state in combined_embeds]
856+
negative_prompt_embeds = [state[split_idx:] for state in combined_embeds]
857+
else:
858+
prompt_embeds = combined_embeds[:split_idx]
859+
negative_prompt_embeds = combined_embeds[split_idx:]
860+
861+
prompt_attention_mask = combined_mask[:split_idx]
862+
negative_prompt_attention_mask = combined_mask[split_idx:]
863+
else:
864+
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
865+
prompt=prompt,
866+
num_videos_per_prompt=num_videos_per_prompt,
867+
max_sequence_length=max_sequence_length,
868+
scale_factor=scale_factor,
869+
dtype=dtype,
870+
)
871+
843872
if do_classifier_free_guidance and negative_prompt_embeds is None:
844873
negative_prompt = negative_prompt or ""
845874
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

0 commit comments

Comments (0)