Skip to content

Commit 1463ea9

Browse files
committed
Trying text_mask 9
1 parent 1fbfd5b commit 1463ea9

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def _prepare_model_inputs_i2v(
633633
effective_batch_size = batch_size * num_videos_per_prompt
634634

635635
# 1. Encode Prompts
636-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
636+
prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
637637
prompt=prompt,
638638
negative_prompt=negative_prompt,
639639
num_videos_per_prompt=num_videos_per_prompt,
@@ -677,9 +677,12 @@ def _prepare_model_inputs_i2v(
677677

678678
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
679679
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
680+
text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
681+
negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
680682
image_embeds = jax.device_put(image_embeds, data_sharding)
681683

682-
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
684+
return (prompt_embeds, negative_prompt_embeds, text_attention_mask,
685+
negative_text_attention_mask, image_embeds, effective_batch_size)
683686

684687

685688
def _prepare_model_inputs(

0 commit comments

Comments
 (0)