Skip to content

Commit 4907d08

Browse files
committed
Trying text_mask 2/3
1 parent 00077c1 commit 4907d08

3 files changed

Lines changed: 25 additions & 8 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,10 @@ def __call__(
11101110
value_proj = checkpoint_name(value_proj, "value_proj")
11111111

11121112
with jax.named_scope("apply_attention"):
1113-
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
1113+
if is_self_attention:
1114+
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
1115+
else:
1116+
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask)
11141117

11151118
else:
11161119
# NEW PATH for I2V CROSS-ATTENTION

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ def __call__(
587587
timestep: jax.Array,
588588
encoder_hidden_states: jax.Array,
589589
encoder_hidden_states_image: Optional[jax.Array] = None,
590+
encoder_attention_mask: Optional[jax.Array] = None,
590591
return_dict: bool = True,
591592
attention_kwargs: Optional[Dict[str, Any]] = None,
592593
deterministic: bool = True,
@@ -606,17 +607,30 @@ def __call__(
606607
hidden_states = self.patch_embedding(hidden_states)
607608
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
608609
with self.conditional_named_scope("condition_embedder"):
609-
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = self.condition_embedder(
610+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, image_attention_mask = self.condition_embedder(
610611
timestep, encoder_hidden_states, encoder_hidden_states_image
611612
)
612613
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
613614

615+
# Handle masks for T2V vs I2V
614616
if encoder_hidden_states_image is not None:
617+
# I2V case: concatenate image and text embeddings
615618
encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
616-
if encoder_attention_mask is not None:
617-
text_mask = jnp.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), dtype=jnp.int32)
618-
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
619+
620+
# Build combined mask: [image_mask | text_mask]
621+
if image_attention_mask is not None:
622+
# We have image mask from embedder
623+
if encoder_attention_mask is not None:
624+
# Use passed text mask (from pipeline)
625+
combined_mask = jnp.concatenate([image_attention_mask, encoder_attention_mask], axis=1)
626+
else:
627+
# No text mask passed, use all-ones (old behavior for backward compat)
628+
text_len = encoder_hidden_states.shape[1] - image_attention_mask.shape[1]
629+
text_mask = jnp.ones((encoder_hidden_states.shape[0], text_len), dtype=jnp.int32)
630+
combined_mask = jnp.concatenate([image_attention_mask, text_mask], axis=1)
631+
encoder_attention_mask = combined_mask
619632
encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
633+
# For T2V: encoder_attention_mask is already the text mask passed from pipeline
620634

621635
if self.scan_layers:
622636

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -777,13 +777,13 @@ def transformer_forward_pass(
777777
# jax.debug.print("[DEBUG transformer_forward_pass] encoder_attention_mask shape: {}",
778778
# encoder_attention_mask.shape if encoder_attention_mask is not None else "None")
779779

780-
# For now, DON'T pass the mask - just accept it
780+
# Now actually pass the mask to the transformer
781781
noise_pred = wan_transformer(
782782
hidden_states=latents,
783783
timestep=timestep,
784784
encoder_hidden_states=prompt_embeds,
785-
encoder_hidden_states_image=encoder_hidden_states_image
786-
# encoder_attention_mask=encoder_attention_mask # TODO: Add this next
785+
encoder_hidden_states_image=encoder_hidden_states_image,
786+
encoder_attention_mask=encoder_attention_mask
787787
)
788788
if do_classifier_free_guidance:
789789
bsz = latents.shape[0] // 2

0 commit comments

Comments (0)