Skip to content

Commit acc7452

Browse files
committed
Trying text_mask 1
1 parent 7543d00 commit acc7452

2 files changed

Lines changed: 41 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,18 +434,36 @@ def _get_t5_prompt_embeds(
434434
)
435435
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
436436
seq_lens = mask.gt(0).sum(dim=1).long()
437+
438+
# DEBUG
439+
print(f"[DEBUG _get_t5_prompt_embeds] seq_lens: {seq_lens.tolist()}, mask shape: {mask.shape}")
440+
437441
prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
438442
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
439443
prompt_embeds = torch.stack(
440444
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
441445
)
442446

447+
# Create text attention mask
448+
text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.long)
449+
for i, seq_len_i in enumerate(seq_lens):
450+
text_attention_mask[i, :seq_len_i] = 1
451+
452+
print(f"[DEBUG _get_t5_prompt_embeds] text_attention_mask shape: {text_attention_mask.shape}, sum: {text_attention_mask.sum(dim=1).tolist()}")
453+
443454
# duplicate text embeddings for each generation per prompt, using mps friendly method
444455
_, seq_len, _ = prompt_embeds.shape
445456
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
446457
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
458+
459+
# Duplicate mask
460+
text_attention_mask = text_attention_mask.repeat(1, num_videos_per_prompt)
461+
text_attention_mask = text_attention_mask.view(batch_size * num_videos_per_prompt, max_sequence_length)
462+
text_attention_mask_jax = jnp.array(text_attention_mask.numpy())
463+
464+
print(f"[DEBUG _get_t5_prompt_embeds] After duplication - mask shape: {text_attention_mask_jax.shape}")
447465

448-
return prompt_embeds
466+
return prompt_embeds, text_attention_mask_jax
449467

450468
def encode_prompt(
451469
self,
@@ -459,24 +477,31 @@ def encode_prompt(
459477
prompt = [prompt] if isinstance(prompt, str) else prompt
460478
batch_size = len(prompt)
461479
if prompt_embeds is None:
462-
prompt_embeds = self._get_t5_prompt_embeds(
480+
prompt_embeds, text_attention_mask = self._get_t5_prompt_embeds(
463481
prompt=prompt,
464482
num_videos_per_prompt=num_videos_per_prompt,
465483
max_sequence_length=max_sequence_length,
466484
)
467485
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
486+
else:
487+
text_attention_mask = None
468488

469489
if negative_prompt_embeds is None:
470490
negative_prompt = negative_prompt or ""
471491
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
472-
negative_prompt_embeds = self._get_t5_prompt_embeds(
492+
negative_prompt_embeds, negative_text_attention_mask = self._get_t5_prompt_embeds(
473493
prompt=negative_prompt,
474494
num_videos_per_prompt=num_videos_per_prompt,
475495
max_sequence_length=max_sequence_length,
476496
)
477497
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
498+
else:
499+
negative_text_attention_mask = None
500+
501+
print(f"[DEBUG encode_prompt] text_attention_mask: {text_attention_mask.shape if text_attention_mask is not None else None}")
502+
print(f"[DEBUG encode_prompt] negative_text_attention_mask: {negative_text_attention_mask.shape if negative_text_attention_mask is not None else None}")
478503

479-
return prompt_embeds, negative_prompt_embeds
504+
return prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask
480505

481506
def prepare_latents(
482507
self,
@@ -687,13 +712,15 @@ def _prepare_model_inputs(
687712
batch_size = len(prompt)
688713

689714
with jax.named_scope("Encode-Prompt"):
690-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
715+
prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
691716
prompt=prompt,
692717
negative_prompt=negative_prompt,
693718
max_sequence_length=max_sequence_length,
694719
prompt_embeds=prompt_embeds,
695720
negative_prompt_embeds=negative_prompt_embeds,
696721
)
722+
723+
print(f"[DEBUG _prepare_model_inputs] Got masks - text: {text_attention_mask.shape if text_attention_mask is not None else None}, neg: {negative_text_attention_mask.shape if negative_text_attention_mask is not None else None}")
697724

698725
num_channel_latents = self._get_num_channel_latents()
699726
if latents is None:
@@ -715,12 +742,16 @@ def _prepare_model_inputs(
715742
latents = jax.device_put(latents, data_sharding)
716743
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
717744
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
745+
if text_attention_mask is not None:
746+
text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
747+
if negative_text_attention_mask is not None:
748+
negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
718749

719750
scheduler_state = self.scheduler.set_timesteps(
720751
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
721752
)
722753

723-
return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
754+
return latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames
724755

725756
@abstractmethod
726757
def __call__(self, **kwargs):

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __call__(
8989
negative_prompt_embeds: Optional[jax.Array] = None,
9090
vae_only: bool = False,
9191
):
92-
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
92+
latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames = self._prepare_model_inputs(
9393
prompt,
9494
negative_prompt,
9595
height,
@@ -103,6 +103,9 @@ def __call__(
103103
negative_prompt_embeds,
104104
vae_only,
105105
)
106+
107+
print(f"[DEBUG WAN21T2V __call__] text_attention_mask: {text_attention_mask.shape if text_attention_mask is not None else None}")
108+
print(f"[DEBUG WAN21T2V __call__] negative_text_attention_mask: {negative_text_attention_mask.shape if negative_text_attention_mask is not None else None}")
106109

107110
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
108111

0 commit comments

Comments
 (0)