Skip to content

Commit 1590039

Browse files
committed
Add attention mask support for Wan model
1 parent 1d93c63 commit 1590039

7 files changed

Lines changed: 109 additions & 32 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,7 @@ def __call__(
11791179
value_proj = checkpoint_name(value_proj, "value_proj")
11801180

11811181
with jax.named_scope("apply_attention"):
1182-
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
1182+
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask)
11831183

11841184
else:
11851185
# NEW PATH for I2V CROSS-ATTENTION
@@ -1206,9 +1206,11 @@ def __call__(
12061206
# It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
12071207
if encoder_attention_mask is not None:
12081208
encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
1209+
encoder_attention_mask_text = encoder_attention_mask[:, padded_img_len:]
12091210
else:
12101211
# Fallback: no mask means treat all as valid (for dot product attention)
12111212
encoder_attention_mask_img = None
1213+
encoder_attention_mask_text = None
12121214
else:
12131215
# If no image_seq_len is specified, treat all as text
12141216
encoder_hidden_states_img = None
@@ -1257,7 +1259,7 @@ def __call__(
12571259

12581260
# Attention - tensors are (B, S, D)
12591261
with self.conditional_named_scope("cross_attn_text_apply"):
1260-
attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
1262+
attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)
12611263
with self.conditional_named_scope("cross_attn_img_apply"):
12621264
# Pass encoder_attention_mask_img for image cross-attention to mask padded tokens
12631265
attn_output_img = self.attention_op.apply_attention(
@@ -1272,7 +1274,7 @@ def __call__(
12721274
value_proj_text = checkpoint_name(value_proj_text, "value_proj_text")
12731275

12741276
with self.conditional_named_scope("cross_attn_text_apply"):
1275-
attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
1277+
attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask)
12761278

12771279
attn_output = attn_output.astype(dtype=dtype)
12781280
attn_output = checkpoint_name(attn_output, "attn_output")

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def compute_kv_cache(
606606
encoder_hidden_states: jax.Array,
607607
encoder_hidden_states_image: Optional[jax.Array] = None,
608608
timestep: Optional[jax.Array] = None,
609+
text_mask: Optional[jax.Array] = None,
609610
) -> Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Optional[jax.Array]]:
610611
if timestep is None:
611612
batch_size = encoder_hidden_states.shape[0]
@@ -623,11 +624,15 @@ def compute_kv_cache(
623624
if encoder_hidden_states_image is not None:
624625
encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
625626
if encoder_attention_mask is not None:
626-
text_mask = jnp.ones(
627-
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
628-
dtype=jnp.int32,
629-
)
627+
if text_mask is None:
628+
text_mask = jnp.ones(
629+
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
630+
dtype=jnp.int32,
631+
)
630632
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
633+
else:
634+
if encoder_attention_mask is None:
635+
encoder_attention_mask = text_mask
631636

632637
if self.scan_layers:
633638
@nnx.vmap(in_axes=(0, None, None), out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
@@ -665,6 +670,7 @@ def __call__(
665670
kv_cache: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
666671
rotary_emb: Optional[jax.Array] = None,
667672
encoder_attention_mask: Optional[jax.Array] = None,
673+
text_mask: Optional[jax.Array] = None,
668674
) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]:
669675
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
670676
batch_size, _, num_frames, height, width = hidden_states.shape
@@ -694,14 +700,17 @@ def __call__(
694700
encoder_attention_mask = encoder_attention_mask
695701
else:
696702
encoder_attention_mask = encoder_attention_mask_out
703+
if encoder_attention_mask is None:
704+
encoder_attention_mask = text_mask
697705

698706
if encoder_hidden_states_image is not None:
699707
encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states_out], axis=1)
700708
if kv_cache is None and encoder_attention_mask is not None:
701-
text_mask = jnp.ones(
702-
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
703-
dtype=jnp.int32,
704-
)
709+
if text_mask is None:
710+
text_mask = jnp.ones(
711+
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
712+
dtype=jnp.int32,
713+
)
705714
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
706715
encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
707716
else:

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def _get_t5_prompt_embeds(
473473
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
474474
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
475475

476-
return prompt_embeds
476+
mask = mask.repeat(1, num_videos_per_prompt)
477+
mask = mask.view(batch_size * num_videos_per_prompt, seq_len)
478+
mask = jnp.array(mask.detach().numpy(), dtype=jnp.int32)
479+
480+
return prompt_embeds, mask
477481

478482
def encode_prompt(
479483
self,
@@ -483,28 +487,36 @@ def encode_prompt(
483487
max_sequence_length: int = 226,
484488
prompt_embeds: jax.Array = None,
485489
negative_prompt_embeds: jax.Array = None,
490+
prompt_mask: jax.Array = None,
491+
negative_prompt_mask: jax.Array = None,
486492
):
487493
prompt = [prompt] if isinstance(prompt, str) else prompt
488494
if prompt_embeds is None:
489-
prompt_embeds = self._get_t5_prompt_embeds(
495+
prompt_embeds, prompt_mask = self._get_t5_prompt_embeds(
490496
prompt=prompt,
491497
num_videos_per_prompt=num_videos_per_prompt,
492498
max_sequence_length=max_sequence_length,
493499
)
494500
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
501+
else:
502+
if prompt_mask is None:
503+
prompt_mask = jnp.ones((prompt_embeds.shape[0], prompt_embeds.shape[1]), dtype=jnp.int32)
495504

496505
if negative_prompt_embeds is None:
497506
batch_size = len(prompt_embeds)
498507
negative_prompt = negative_prompt or ""
499508
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
500-
negative_prompt_embeds = self._get_t5_prompt_embeds(
509+
negative_prompt_embeds, negative_prompt_mask = self._get_t5_prompt_embeds(
501510
prompt=negative_prompt,
502511
num_videos_per_prompt=num_videos_per_prompt,
503512
max_sequence_length=max_sequence_length,
504513
)
505514
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
515+
else:
516+
if negative_prompt_mask is None:
517+
negative_prompt_mask = jnp.ones((negative_prompt_embeds.shape[0], negative_prompt_embeds.shape[1]), dtype=jnp.int32)
506518

507-
return prompt_embeds, negative_prompt_embeds
519+
return prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask
508520

509521
def prepare_latents(
510522
self,
@@ -647,7 +659,7 @@ def _prepare_model_inputs_i2v(
647659
effective_batch_size = batch_size * num_videos_per_prompt
648660

649661
# 1. Encode Prompts
650-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
662+
prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
651663
prompt=prompt,
652664
negative_prompt=negative_prompt,
653665
num_videos_per_prompt=num_videos_per_prompt,
@@ -691,8 +703,10 @@ def _prepare_model_inputs_i2v(
691703
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
692704
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
693705
image_embeds = jax.device_put(image_embeds, data_sharding)
706+
prompt_mask = jax.device_put(prompt_mask, data_sharding)
707+
negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)
694708

695-
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
709+
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size, prompt_mask, negative_prompt_mask
696710

697711
def _prepare_model_inputs(
698712
self,
@@ -724,7 +738,7 @@ def _prepare_model_inputs(
724738
batch_size = len(prompt)
725739

726740
with jax.named_scope("Encode-Prompt"):
727-
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
741+
prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
728742
prompt=prompt,
729743
negative_prompt=negative_prompt,
730744
max_sequence_length=max_sequence_length,
@@ -752,12 +766,14 @@ def _prepare_model_inputs(
752766
latents = jax.device_put(latents, data_sharding)
753767
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
754768
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
769+
prompt_mask = jax.device_put(prompt_mask, data_sharding)
770+
negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)
755771

756772
scheduler_state = self.scheduler.set_timesteps(
757773
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
758774
)
759775

760-
return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
776+
return latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames
761777

762778
@abstractmethod
763779
def __call__(self, **kwargs):
@@ -782,6 +798,7 @@ def transformer_forward_pass(
782798
kv_cache=None,
783799
rotary_emb=None,
784800
encoder_attention_mask=None,
801+
text_mask=None,
785802
):
786803
wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
787804
outputs = wan_transformer(
@@ -795,6 +812,7 @@ def transformer_forward_pass(
795812
kv_cache=kv_cache,
796813
rotary_emb=rotary_emb,
797814
encoder_attention_mask=encoder_attention_mask,
815+
text_mask=text_mask,
798816
)
799817

800818
if return_residual:
@@ -828,6 +846,7 @@ def transformer_forward_pass_full_cfg(
828846
kv_cache=None,
829847
rotary_emb=None,
830848
encoder_attention_mask=None,
849+
text_mask=None,
831850
):
832851
"""Full CFG forward pass.
833852
@@ -849,6 +868,7 @@ def transformer_forward_pass_full_cfg(
849868
kv_cache=kv_cache,
850869
rotary_emb=rotary_emb,
851870
encoder_attention_mask=encoder_attention_mask,
871+
text_mask=text_mask,
852872
)
853873
noise_cond = noise_pred[:bsz]
854874
noise_uncond = noise_pred[bsz:]
@@ -873,6 +893,7 @@ def transformer_forward_pass_cfg_cache(
873893
kv_cache=None,
874894
rotary_emb=None,
875895
encoder_attention_mask=None,
896+
text_mask=None,
876897
):
877898
"""CFG-Cache forward pass with FFT frequency-domain compensation.
878899
@@ -901,6 +922,7 @@ def transformer_forward_pass_cfg_cache(
901922
kv_cache=kv_cache,
902923
rotary_emb=rotary_emb,
903924
encoder_attention_mask=encoder_attention_mask,
925+
text_mask=text_mask,
904926
)
905927

906928
# FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W]

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __call__(
111111
"CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0."
112112
)
113113

114-
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
114+
latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames = self._prepare_model_inputs(
115115
prompt,
116116
negative_prompt,
117117
height,
@@ -152,6 +152,8 @@ def __call__(
152152
latents=latents,
153153
prompt_embeds=prompt_embeds,
154154
negative_prompt_embeds=negative_prompt_embeds,
155+
prompt_mask=prompt_mask,
156+
negative_prompt_mask=negative_prompt_mask,
155157
)
156158
latents = self._denormalize_latents(latents)
157159
return self._decode_latents_to_video(latents)
@@ -164,6 +166,8 @@ def run_inference_2_1(
164166
latents: jnp.array,
165167
prompt_embeds: jnp.array,
166168
negative_prompt_embeds: jnp.array,
169+
prompt_mask: jnp.array,
170+
negative_prompt_mask: jnp.array,
167171
guidance_scale: float,
168172
num_inference_steps: int,
169173
scheduler: FlaxUniPCMultistepScheduler,
@@ -216,8 +220,12 @@ def run_inference_2_1(
216220
# Pre-split embeds once, outside the loop.
217221
prompt_cond_embeds = prompt_embeds
218222
prompt_embeds_combined = None
223+
prompt_mask_combined = None
219224
if do_cfg:
220225
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
226+
prompt_mask_combined = jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0)
227+
else:
228+
prompt_mask_combined = prompt_mask
221229

222230
# Pre-compute cache schedule and phase-dependent weights.
223231
# t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
@@ -257,7 +265,9 @@ def run_inference_2_1(
257265
encoder_attention_mask = None
258266

259267
if use_kv_cache:
260-
kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds_combined if do_cfg else prompt_cond_embeds)
268+
kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds_combined if do_cfg else prompt_cond_embeds, text_mask=prompt_mask_combined)
269+
else:
270+
encoder_attention_mask = prompt_mask_combined
261271

262272
if use_magcache and do_cfg:
263273
magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __call__(
131131
"SenCache requires classifier-free guidance to be enabled for both transformer phases."
132132
)
133133

134-
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
134+
latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames = self._prepare_model_inputs(
135135
prompt,
136136
negative_prompt,
137137
height,
@@ -176,6 +176,8 @@ def __call__(
176176
latents=latents,
177177
prompt_embeds=prompt_embeds,
178178
negative_prompt_embeds=negative_prompt_embeds,
179+
prompt_mask=prompt_mask,
180+
negative_prompt_mask=negative_prompt_mask,
179181
)
180182
latents = self._denormalize_latents(latents)
181183
return self._decode_latents_to_video(latents)
@@ -191,6 +193,8 @@ def run_inference_2_2(
191193
latents: jnp.array,
192194
prompt_embeds: jnp.array,
193195
negative_prompt_embeds: jnp.array,
196+
prompt_mask: jnp.array,
197+
negative_prompt_mask: jnp.array,
194198
guidance_scale_low: float,
195199
guidance_scale_high: float,
196200
boundary: int,
@@ -223,6 +227,9 @@ def run_inference_2_2(
223227
prompt_embeds_combined = (
224228
jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
225229
)
230+
prompt_mask_combined = (
231+
jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0) if do_classifier_free_guidance else prompt_mask
232+
)
226233

227234
low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest)
228235

@@ -236,10 +243,13 @@ def run_inference_2_2(
236243
encoder_attention_mask_high = None
237244

238245
if use_kv_cache:
239-
kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(prompt_embeds_combined)
246+
kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(prompt_embeds_combined, text_mask=prompt_mask_combined)
240247

241248
high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest)
242-
kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined)
249+
kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined, text_mask=prompt_mask_combined)
250+
else:
251+
encoder_attention_mask_low = prompt_mask_combined
252+
encoder_attention_mask_high = prompt_mask_combined
243253

244254
# ── SenCache path (arXiv:2602.24208) ──
245255
if use_sen_cache and do_classifier_free_guidance:

0 commit comments

Comments (0)