Skip to content

Commit 3538f1a

Browse files
committed
text attn mask fix
1 parent a1f291c commit 3538f1a

6 files changed

Lines changed: 103 additions & 42 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,24 @@ def __init__(
151151
)
152152

153153
def __call__(
154-
self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None
154+
self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None, text_attention_mask: Optional[jax.Array] = None
155155
):
156156
timestep = self.timesteps_proj(timestep)
157157
temb = self.time_embedder(timestep)
158158
with jax.named_scope("time_proj"):
159159
timestep_proj = self.time_proj(self.act_fn(temb))
160160

161161
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
162-
encoder_attention_mask = None
162+
# Start with text attention mask (can be None for backward compatibility)
163+
encoder_attention_mask = text_attention_mask
164+
163165
if encoder_hidden_states_image is not None:
164-
encoder_hidden_states_image, encoder_attention_mask = self.image_embedder(encoder_hidden_states_image)
165-
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask
166+
# For I2V: image embedder returns image embeddings and image mask
167+
encoder_hidden_states_image, image_attention_mask = self.image_embedder(encoder_hidden_states_image)
168+
# Store image mask separately - will be concatenated with text mask in WanModel
169+
encoder_attention_mask = image_attention_mask
170+
171+
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask, text_attention_mask
166172

167173

168174
class ApproximateGELU(nnx.Module):
@@ -373,6 +379,7 @@ def __call__(
373379
rotary_emb: jax.Array,
374380
deterministic: bool = True,
375381
rngs: nnx.Rngs = None,
382+
encoder_attention_mask: Optional[jax.Array] = None,
376383
):
377384
with self.conditional_named_scope("transformer_block"):
378385
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
@@ -409,6 +416,7 @@ def __call__(
409416
encoder_hidden_states=encoder_hidden_states,
410417
deterministic=deterministic,
411418
rngs=rngs,
419+
encoder_attention_mask=encoder_attention_mask,
412420
)
413421
with self.conditional_named_scope("cross_attn_residual"):
414422
hidden_states = hidden_states + attn_output
@@ -585,6 +593,7 @@ def __call__(
585593
timestep: jax.Array,
586594
encoder_hidden_states: jax.Array,
587595
encoder_hidden_states_image: Optional[jax.Array] = None,
596+
text_attention_mask: Optional[jax.Array] = None,
588597
return_dict: bool = True,
589598
attention_kwargs: Optional[Dict[str, Any]] = None,
590599
deterministic: bool = True,
@@ -604,24 +613,36 @@ def __call__(
604613
hidden_states = self.patch_embedding(hidden_states)
605614
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
606615
with self.conditional_named_scope("condition_embedder"):
607-
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = self.condition_embedder(
608-
timestep, encoder_hidden_states, encoder_hidden_states_image
616+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask, text_mask_from_embedder = self.condition_embedder(
617+
timestep, encoder_hidden_states, encoder_hidden_states_image, text_attention_mask
609618
)
610619
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
611620

621+
# Prefer the caller's text_attention_mask; the embedder returns that same mask unchanged, so this stays None when no mask was passed
622+
if text_attention_mask is None:
623+
text_attention_mask = text_mask_from_embedder
624+
612625
if encoder_hidden_states_image is not None:
626+
# I2V case: concatenate image + text embeddings and their masks
613627
encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
614-
if encoder_attention_mask is not None:
615-
text_mask = jnp.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), dtype=jnp.int32)
616-
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
628+
if encoder_attention_mask is not None and text_attention_mask is not None:
629+
# Concatenate image mask + text mask along axis 1, in the same order as the embeddings above (image first, then text)
630+
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_attention_mask], axis=1)
631+
elif text_attention_mask is not None:
632+
# Only text mask available (shouldn't happen in I2V, but handle gracefully)
633+
encoder_attention_mask = text_attention_mask
634+
# else: no text mask; encoder_attention_mask remains as-is (the image mask, or None if the image embedder returned no mask)
617635
encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
636+
elif text_attention_mask is not None:
637+
# T2V case: only text, use text mask directly
638+
encoder_attention_mask = text_attention_mask
618639

619640
if self.scan_layers:
620641

621642
def scan_fn(carry, block):
622643
hidden_states_carry, rngs_carry = carry
623644
hidden_states = block(
624-
hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry
645+
hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry, encoder_attention_mask
625646
)
626647
new_carry = (hidden_states, rngs_carry)
627648
return new_carry, None

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,6 @@ def _get_t5_prompt_embeds(
419419
num_videos_per_prompt: int = 1,
420420
max_sequence_length: int = 226,
421421
):
422-
jax.debug.print("Prompt Shape")
423-
jax.debug.print(f"Length of prompt list: {len(prompt)}")
424-
425422
prompt = [prompt] if isinstance(prompt, str) else prompt
426423
prompt = [prompt_clean(u) for u in prompt]
427424
batch_size = len(prompt)
@@ -436,25 +433,29 @@ def _get_t5_prompt_embeds(
436433
return_tensors="pt",
437434
)
438435
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
439-
print("Text Input IDS")
440-
print(text_input_ids)
441-
print(text_input_ids.shape)
442-
print("Mask")
443-
print(mask)
444-
print(mask.shape)
445436
seq_lens = mask.gt(0).sum(dim=1).long()
446437
prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
447438
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
448439
prompt_embeds = torch.stack(
449440
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
450441
)
451442

443+
# Create attention mask: 1 for real tokens, 0 for padded tokens
444+
# This mask reflects the actual content after trimming and re-padding with zeros
445+
text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.int32, device=mask.device)
446+
for i, length in enumerate(seq_lens):
447+
text_attention_mask[i, :length] = 1
448+
452449
# duplicate text embeddings for each generation per prompt, using mps friendly method
453450
_, seq_len, _ = prompt_embeds.shape
454451
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
455452
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
453+
454+
# duplicate attention mask for each generation per prompt
455+
text_attention_mask = text_attention_mask.repeat(1, num_videos_per_prompt)
456+
text_attention_mask = text_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
456457

457-
return prompt_embeds
458+
return prompt_embeds, text_attention_mask
458459

459460
def encode_prompt(
460461
self,
@@ -467,25 +468,30 @@ def encode_prompt(
467468
):
468469
prompt = [prompt] if isinstance(prompt, str) else prompt
469470
batch_size = len(prompt)
471+
text_attention_mask = None
472+
negative_text_attention_mask = None
473+
470474
if prompt_embeds is None:
471-
prompt_embeds = self._get_t5_prompt_embeds(
475+
prompt_embeds, text_attention_mask = self._get_t5_prompt_embeds(
472476
prompt=prompt,
473477
num_videos_per_prompt=num_videos_per_prompt,
474478
max_sequence_length=max_sequence_length,
475479
)
476480
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
481+
text_attention_mask = jnp.array(text_attention_mask.detach().numpy(), dtype=jnp.int32)
477482

478483
if negative_prompt_embeds is None:
479484
negative_prompt = negative_prompt or ""
480485
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
481-
negative_prompt_embeds = self._get_t5_prompt_embeds(
486+
negative_prompt_embeds, negative_text_attention_mask = self._get_t5_prompt_embeds(
482487
prompt=negative_prompt,
483488
num_videos_per_prompt=num_videos_per_prompt,
484489
max_sequence_length=max_sequence_length,
485490
)
486491
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
492+
negative_text_attention_mask = jnp.array(negative_text_attention_mask.detach().numpy(), dtype=jnp.int32)
487493

488-
return prompt_embeds, negative_prompt_embeds
494+
return prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask
489495

490496
def prepare_latents(
491497
self,
@@ -617,7 +623,7 @@ def _prepare_model_inputs_i2v(
617623
effective_batch_size = batch_size * num_videos_per_prompt
618624

619625
# 1. Encode Prompts
620-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
626+
prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
621627
prompt=prompt,
622628
negative_prompt=negative_prompt,
623629
num_videos_per_prompt=num_videos_per_prompt,
@@ -662,8 +668,10 @@ def _prepare_model_inputs_i2v(
662668
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
663669
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
664670
image_embeds = jax.device_put(image_embeds, data_sharding)
671+
text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
672+
negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
665673

666-
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
674+
return prompt_embeds, negative_prompt_embeds, image_embeds, text_attention_mask, negative_text_attention_mask, effective_batch_size
667675

668676

669677
def _prepare_model_inputs(
@@ -696,7 +704,7 @@ def _prepare_model_inputs(
696704
batch_size = len(prompt)
697705

698706
with jax.named_scope("Encode-Prompt"):
699-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
707+
prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
700708
prompt=prompt,
701709
negative_prompt=negative_prompt,
702710
max_sequence_length=max_sequence_length,
@@ -724,12 +732,14 @@ def _prepare_model_inputs(
724732
latents = jax.device_put(latents, data_sharding)
725733
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
726734
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
735+
text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
736+
negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
727737

728738
scheduler_state = self.scheduler.set_timesteps(
729739
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
730740
)
731741

732-
return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
742+
return latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames
733743

734744
@abstractmethod
735745
def __call__(self, **kwargs):
@@ -747,9 +757,10 @@ def transformer_forward_pass(
747757
do_classifier_free_guidance,
748758
guidance_scale,
749759
encoder_hidden_states_image=None,
760+
text_attention_mask=None,
750761
):
751762
wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
752-
noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image)
763+
noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image, text_attention_mask=text_attention_mask)
753764
if do_classifier_free_guidance:
754765
bsz = latents.shape[0] // 2
755766
noise_cond = noise_pred[:bsz] # First half = conditional

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __call__(
8989
negative_prompt_embeds: Optional[jax.Array] = None,
9090
vae_only: bool = False,
9191
):
92-
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
92+
latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames = self._prepare_model_inputs(
9393
prompt,
9494
negative_prompt,
9595
height,
@@ -122,6 +122,8 @@ def __call__(
122122
latents=latents,
123123
prompt_embeds=prompt_embeds,
124124
negative_prompt_embeds=negative_prompt_embeds,
125+
text_attention_mask=text_attention_mask,
126+
negative_text_attention_mask=negative_text_attention_mask,
125127
)
126128
latents = self._denormalize_latents(latents)
127129
return self._decode_latents_to_video(latents)
@@ -133,6 +135,8 @@ def run_inference_2_1(
133135
latents: jnp.array,
134136
prompt_embeds: jnp.array,
135137
negative_prompt_embeds: jnp.array,
138+
text_attention_mask: jnp.array,
139+
negative_text_attention_mask: jnp.array,
136140
guidance_scale: float,
137141
num_inference_steps: int,
138142
scheduler: FlaxUniPCMultistepScheduler,
@@ -141,6 +145,7 @@ def run_inference_2_1(
141145
do_classifier_free_guidance = guidance_scale > 1.0
142146
if do_classifier_free_guidance:
143147
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
148+
text_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
144149
for step in range(num_inference_steps):
145150
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
146151
if do_classifier_free_guidance:
@@ -156,6 +161,7 @@ def run_inference_2_1(
156161
prompt_embeds,
157162
do_classifier_free_guidance=do_classifier_free_guidance,
158163
guidance_scale=guidance_scale,
164+
text_attention_mask=text_attention_mask,
159165
)
160166

161167
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __call__(
100100
negative_prompt_embeds: jax.Array = None,
101101
vae_only: bool = False,
102102
):
103-
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
103+
latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames = self._prepare_model_inputs(
104104
prompt,
105105
negative_prompt,
106106
height,
@@ -139,6 +139,8 @@ def __call__(
139139
latents=latents,
140140
prompt_embeds=prompt_embeds,
141141
negative_prompt_embeds=negative_prompt_embeds,
142+
text_attention_mask=text_attention_mask,
143+
negative_text_attention_mask=negative_text_attention_mask,
142144
)
143145
latents = self._denormalize_latents(latents)
144146
return self._decode_latents_to_video(latents)
@@ -153,6 +155,8 @@ def run_inference_2_2(
153155
latents: jnp.array,
154156
prompt_embeds: jnp.array,
155157
negative_prompt_embeds: jnp.array,
158+
text_attention_mask: jnp.array,
159+
negative_text_attention_mask: jnp.array,
156160
guidance_scale_low: float,
157161
guidance_scale_high: float,
158162
boundary: int,
@@ -163,21 +167,24 @@ def run_inference_2_2(
163167
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
164168
if do_classifier_free_guidance:
165169
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
170+
text_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
166171

167172
def low_noise_branch(operands):
168-
latents, timestep, prompt_embeds = operands
173+
latents, timestep, prompt_embeds, text_attention_mask = operands
169174
return transformer_forward_pass(
170175
low_noise_graphdef, low_noise_state, low_noise_rest,
171176
latents, timestep, prompt_embeds,
172-
do_classifier_free_guidance, guidance_scale_low
177+
do_classifier_free_guidance, guidance_scale_low,
178+
text_attention_mask=text_attention_mask
173179
)
174180

175181
def high_noise_branch(operands):
176-
latents, timestep, prompt_embeds = operands
182+
latents, timestep, prompt_embeds, text_attention_mask = operands
177183
return transformer_forward_pass(
178184
high_noise_graphdef, high_noise_state, high_noise_rest,
179185
latents, timestep, prompt_embeds,
180-
do_classifier_free_guidance, guidance_scale_high
186+
do_classifier_free_guidance, guidance_scale_high,
187+
text_attention_mask=text_attention_mask
181188
)
182189

183190
for step in range(num_inference_steps):
@@ -195,7 +202,7 @@ def high_noise_branch(operands):
195202
use_high_noise,
196203
high_noise_branch,
197204
low_noise_branch,
198-
(latents, timestep, prompt_embeds)
205+
(latents, timestep, prompt_embeds, text_attention_mask)
199206
)
200207

201208
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __call__(
171171
max_logging.log(f"Adjusted num_frames to: {num_frames}")
172172
num_frames = max(num_frames, 1)
173173

174-
prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
174+
prompt_embeds, negative_prompt_embeds, image_embeds, text_attention_mask, negative_text_attention_mask, effective_batch_size = self._prepare_model_inputs_i2v(
175175
prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length,
176176
prompt_embeds, negative_prompt_embeds, image_embeds, last_image
177177
)
@@ -212,6 +212,8 @@ def __call__(
212212
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
213213
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
214214
image_embeds = jax.device_put(image_embeds, data_sharding)
215+
text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
216+
negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
215217
if first_frame_mask is not None:
216218
first_frame_mask = jax.device_put(first_frame_mask, data_sharding)
217219

@@ -233,6 +235,8 @@ def __call__(
233235
prompt_embeds=prompt_embeds,
234236
negative_prompt_embeds=negative_prompt_embeds,
235237
image_embeds=image_embeds,
238+
text_attention_mask=text_attention_mask,
239+
negative_text_attention_mask=negative_text_attention_mask,
236240
first_frame_mask=first_frame_mask,
237241
scheduler_state=scheduler_state,
238242
rng=inference_rng,
@@ -252,6 +256,8 @@ def run_inference_2_1_i2v(
252256
prompt_embeds: jnp.array,
253257
negative_prompt_embeds: jnp.array,
254258
image_embeds: jnp.array,
259+
text_attention_mask: jnp.array,
260+
negative_text_attention_mask: jnp.array,
255261
guidance_scale: float,
256262
num_inference_steps: int,
257263
scheduler: FlaxUniPCMultistepScheduler,
@@ -265,6 +271,7 @@ def run_inference_2_1_i2v(
265271
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
266272
image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
267273
condition = jnp.concatenate([condition] * 2)
274+
text_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
268275
for step in range(num_inference_steps):
269276
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
270277
latents_input = latents
@@ -280,6 +287,7 @@ def run_inference_2_1_i2v(
280287
do_classifier_free_guidance=do_classifier_free_guidance,
281288
guidance_scale=guidance_scale,
282289
encoder_hidden_states_image=image_embeds,
290+
text_attention_mask=text_attention_mask,
283291
)
284292
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
285293
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False)

0 commit comments

Comments
 (0)