
Commit bfec1e4

text attn fix
1 parent 7543d00 commit bfec1e4

4 files changed: 67 additions & 15 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 3 deletions
@@ -1110,7 +1110,7 @@ def __call__(
         value_proj = checkpoint_name(value_proj, "value_proj")
 
         with jax.named_scope("apply_attention"):
-          attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+          attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask)
 
       else:
         # NEW PATH for I2V CROSS-ATTENTION
@@ -1134,14 +1134,17 @@ def __call__(
         # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
         if encoder_attention_mask is not None:
           encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
+          encoder_attention_mask_text = encoder_attention_mask[:, padded_img_len:]
         else:
           # Fallback: no mask means treat all as valid
           encoder_attention_mask_img = None
+          encoder_attention_mask_text = None
       else:
         # If no image_seq_len is specified, treat all as text
         encoder_hidden_states_img = None
         encoder_hidden_states_text = encoder_hidden_states
         encoder_attention_mask_img = None
+        encoder_attention_mask_text = encoder_attention_mask
 
       if self.qk_norm:
         with self.conditional_named_scope("attn_q_norm"):
@@ -1179,7 +1182,7 @@ def __call__(
 
       # Attention - tensors are (B, S, D)
       with self.conditional_named_scope("cross_attn_text_apply"):
-        attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
+        attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)
       with self.conditional_named_scope("cross_attn_img_apply"):
         # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens
         attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img)
@@ -1192,7 +1195,7 @@ def __call__(
       value_proj_text = checkpoint_name(value_proj_text, "value_proj_text")
 
       with self.conditional_named_scope("cross_attn_text_apply"):
-        attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
+        attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)
 
     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
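The single-stream path now threads attention_mask into apply_attention, and both text cross-attention call sites pick up encoder_attention_mask_text. The operator's internals are outside this diff; as a rough sketch of the standard technique (illustrative names and shapes, not maxdiffusion's actual apply_attention), a 0/1 key mask is folded into the attention logits before the softmax:

    import jax
    import jax.numpy as jnp

    def masked_attention(q, k, v, mask=None):
      # q, k, v: (batch, seq, heads, head_dim); mask: (batch, kv_seq) with 1 = real token
      logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * (q.shape[-1] ** -0.5)
      if mask is not None:
        # Broadcast the key mask over heads and query positions; masked keys are
        # pushed to a large negative value so softmax gives them ~zero weight.
        logits = jnp.where(mask[:, None, None, :] == 1, logits, jnp.finfo(logits.dtype).min)
      weights = jax.nn.softmax(logits, axis=-1)
      return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

    q = k = v = jnp.ones((1, 4, 2, 8))
    mask = jnp.array([[1, 1, 1, 0]])  # last key position is padding
    out = masked_attention(q, k, v, mask)  # (1, 4, 2, 8)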

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 18 additions & 4 deletions
@@ -587,6 +587,7 @@ def __call__(
       timestep: jax.Array,
       encoder_hidden_states: jax.Array,
       encoder_hidden_states_image: Optional[jax.Array] = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
       return_dict: bool = True,
       attention_kwargs: Optional[Dict[str, Any]] = None,
       deterministic: bool = True,
@@ -606,17 +607,30 @@ def __call__(
     hidden_states = self.patch_embedding(hidden_states)
     hidden_states = jax.lax.collapse(hidden_states, 1, -1)
     with self.conditional_named_scope("condition_embedder"):
-      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = self.condition_embedder(
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, image_attention_mask = self.condition_embedder(
           timestep, encoder_hidden_states, encoder_hidden_states_image
       )
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
+    # Handle attention mask for I2V vs T2V
     if encoder_hidden_states_image is not None:
+      # I2V case: concatenate [image | text] embeddings
       encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
-      if encoder_attention_mask is not None:
-        text_mask = jnp.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), dtype=jnp.int32)
-        encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
+
+      # Build combined attention mask: [image_mask | text_mask]
+      if image_attention_mask is not None:
+        # Image mask from embedder (e.g., [1]*257 + [0]*127 for padded image)
+        if encoder_attention_mask is not None:
+          # Use the text mask passed from the pipeline
+          combined_mask = jnp.concatenate([image_attention_mask, encoder_attention_mask], axis=1)
+        else:
+          # No text mask provided, assume all text tokens are valid (old behavior)
+          text_len = encoder_hidden_states.shape[1] - image_attention_mask.shape[1]
+          text_mask = jnp.ones((encoder_hidden_states.shape[0], text_len), dtype=jnp.int32)
+          combined_mask = jnp.concatenate([image_attention_mask, text_mask], axis=1)
+        encoder_attention_mask = combined_mask
     encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
+    # For T2V: encoder_attention_mask is already the text mask passed from the pipeline
 
     if self.scan_layers:
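A quick standalone check of the combined-mask construction, using the shapes the diff comments mention (257 real image tokens padded to 384); the 512-token all-valid text mask is a hypothetical stand-in for the mask the pipeline passes in:

    import jax.numpy as jnp

    batch, padded_img_len, real_img_len, text_len = 1, 384, 257, 512
    image_attention_mask = jnp.concatenate(
        [jnp.ones((batch, real_img_len), dtype=jnp.int32),
         jnp.zeros((batch, padded_img_len - real_img_len), dtype=jnp.int32)],
        axis=1,
    )
    encoder_attention_mask = jnp.ones((batch, text_len), dtype=jnp.int32)  # text mask from the pipeline

    combined_mask = jnp.concatenate([image_attention_mask, encoder_attention_mask], axis=1)
    assert combined_mask.shape == (batch, padded_img_len + text_len)  # (1, 896)
    assert int(combined_mask.sum()) == real_img_len + text_len        # only real tokens attended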
src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 34 additions & 7 deletions
@@ -440,12 +440,24 @@ def _get_t5_prompt_embeds(
         [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
     )
 
+    # Create attention mask: 1 for real tokens, 0 for padded tokens
+    text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.long)
+    for i, seq_len_i in enumerate(seq_lens):
+      text_attention_mask[i, :seq_len_i] = 1
+
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     _, seq_len, _ = prompt_embeds.shape
     prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
     prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+    # Duplicate attention mask for each generation per prompt
+    text_attention_mask = text_attention_mask.repeat(1, num_videos_per_prompt)
+    text_attention_mask = text_attention_mask.view(batch_size * num_videos_per_prompt, max_sequence_length)
+
+    # Convert to JAX array
+    text_attention_mask = jnp.array(text_attention_mask.numpy())
 
-    return prompt_embeds
+    return prompt_embeds, text_attention_mask
 
   def encode_prompt(
       self,
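The per-row loop above is clear and runs once per batch; for reference, the same mask can be built without the Python loop. A minimal sketch with hypothetical seq_lens values, using only the torch and jax.numpy imports the file already relies on:

    import torch
    import jax.numpy as jnp

    seq_lens = torch.tensor([7, 12])  # hypothetical per-prompt real token counts
    max_sequence_length = 16
    positions = torch.arange(max_sequence_length).unsqueeze(0)        # (1, L)
    text_attention_mask = (positions < seq_lens.unsqueeze(1)).long()  # (B, L), 1 = real token
    text_attention_mask = jnp.array(text_attention_mask.numpy())      # hand off to JAX, as above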
@@ -459,24 +471,28 @@ def encode_prompt(
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
     if prompt_embeds is None:
-      prompt_embeds = self._get_t5_prompt_embeds(
+      prompt_embeds, text_attention_mask = self._get_t5_prompt_embeds(
           prompt=prompt,
           num_videos_per_prompt=num_videos_per_prompt,
           max_sequence_length=max_sequence_length,
       )
       prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
+    else:
+      text_attention_mask = None
 
     if negative_prompt_embeds is None:
       negative_prompt = negative_prompt or ""
       negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-      negative_prompt_embeds = self._get_t5_prompt_embeds(
+      negative_prompt_embeds, negative_text_attention_mask = self._get_t5_prompt_embeds(
           prompt=negative_prompt,
           num_videos_per_prompt=num_videos_per_prompt,
           max_sequence_length=max_sequence_length,
       )
       negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
+    else:
+      negative_text_attention_mask = None
 
-    return prompt_embeds, negative_prompt_embeds
+    return prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask
 
   def prepare_latents(
       self,
@@ -687,7 +703,7 @@ def _prepare_model_inputs(
     batch_size = len(prompt)
 
     with jax.named_scope("Encode-Prompt"):
-      prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+      prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
           prompt=prompt,
           negative_prompt=negative_prompt,
           max_sequence_length=max_sequence_length,
@@ -715,12 +731,16 @@ def _prepare_model_inputs(
     latents = jax.device_put(latents, data_sharding)
     prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
     negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
+    if text_attention_mask is not None:
+      text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
+    if negative_text_attention_mask is not None:
+      negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
 
     scheduler_state = self.scheduler.set_timesteps(
         self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
     )
 
-    return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
+    return latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames
 
   @abstractmethod
   def __call__(self, **kwargs):
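The masks follow the same jax.device_put path as the embeddings. For context, a minimal sketch of the device_put-with-NamedSharding pattern; the mesh construction and axis name here are illustrative, not the repository's actual data_sharding config:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    data_sharding = NamedSharding(mesh, P("data"))  # shard the batch dimension
    mask = jnp.ones((8, 512), dtype=jnp.int32)      # e.g. (batch, max_sequence_length)
    mask = jax.device_put(mask, data_sharding)      # now laid out across the mesh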
@@ -738,9 +758,16 @@ def transformer_forward_pass(
     do_classifier_free_guidance,
     guidance_scale,
     encoder_hidden_states_image=None,
+    encoder_attention_mask=None,
 ):
   wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
-  noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image)
+  noise_pred = wan_transformer(
+      hidden_states=latents,
+      timestep=timestep,
+      encoder_hidden_states=prompt_embeds,
+      encoder_hidden_states_image=encoder_hidden_states_image,
+      encoder_attention_mask=encoder_attention_mask,
+  )
   if do_classifier_free_guidance:
     bsz = latents.shape[0] // 2
     noise_cond = noise_pred[:bsz]  # First half = conditional
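transformer_forward_pass runs the conditional and unconditional batches stacked along axis 0 (noise_pred[:bsz] is the conditional half, per the comment). The blend itself sits outside this hunk; a sketch of the standard classifier-free-guidance combination that typically follows such a split, under the assumption that is what the elided lines do:

    import jax.numpy as jnp

    def apply_cfg(noise_pred, guidance_scale):
      # Split the stacked prediction back into its halves and blend them.
      bsz = noise_pred.shape[0] // 2
      noise_cond, noise_uncond = noise_pred[:bsz], noise_pred[bsz:]
      return noise_uncond + guidance_scale * (noise_cond - noise_uncond)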

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 9 additions & 1 deletion
@@ -89,7 +89,7 @@ def __call__(
       negative_prompt_embeds: Optional[jax.Array] = None,
       vae_only: bool = False,
   ):
-    latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
+    latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames = self._prepare_model_inputs(
         prompt,
         negative_prompt,
         height,
@@ -122,6 +122,8 @@ def __call__(
         latents=latents,
         prompt_embeds=prompt_embeds,
         negative_prompt_embeds=negative_prompt_embeds,
+        text_attention_mask=text_attention_mask,
+        negative_text_attention_mask=negative_text_attention_mask,
     )
     latents = self._denormalize_latents(latents)
     return self._decode_latents_to_video(latents)
@@ -137,10 +139,15 @@ def run_inference_2_1(
     num_inference_steps: int,
     scheduler: FlaxUniPCMultistepScheduler,
     scheduler_state,
+    text_attention_mask: Optional[jax.Array] = None,
+    negative_text_attention_mask: Optional[jax.Array] = None,
 ):
   do_classifier_free_guidance = guidance_scale > 1.0
   if do_classifier_free_guidance:
     prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
+    # Concatenate text attention masks for CFG
+    if text_attention_mask is not None and negative_text_attention_mask is not None:
+      text_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
   for step in range(num_inference_steps):
     t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
     if do_classifier_free_guidance:
@@ -156,6 +163,7 @@ def run_inference_2_1(
         prompt_embeds,
         do_classifier_free_guidance=do_classifier_free_guidance,
         guidance_scale=guidance_scale,
+        encoder_attention_mask=text_attention_mask,
     )
 
     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
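A small shape check (hypothetical sizes) that the mask concatenation above keeps the masks row-aligned with the already-doubled prompt_embeds batch under CFG:

    import jax.numpy as jnp

    bsz, seq = 2, 512
    text_attention_mask = jnp.ones((bsz, seq), dtype=jnp.int32)
    negative_text_attention_mask = jnp.ones((bsz, seq), dtype=jnp.int32)
    prompt_embeds = jnp.zeros((2 * bsz, seq, 4096))  # already [cond | uncond] stacked

    text_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
    assert text_attention_mask.shape[0] == prompt_embeds.shape[0]  # masks track the doubled batch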