
Commit 7543d00

removed redundant img attn mask
1 parent 0fb882d commit 7543d00

2 files changed

Lines changed: 11 additions & 4 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -1070,6 +1070,7 @@ def __call__(
       hidden_states: jax.Array,
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
@@ -1129,9 +1130,13 @@ def __call__(
       encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
       encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]

-      encoder_attention_mask_img = jnp.ones((encoder_hidden_states_img.shape[0], padded_img_len), dtype=jnp.int32)
-      if image_seq_len_actual < padded_img_len:
-        encoder_attention_mask_img = encoder_attention_mask_img.at[:, image_seq_len_actual:].set(0)
+      # Use the passed encoder_attention_mask (created in embeddings_flax.py)
+      # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
+      if encoder_attention_mask is not None:
+        encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
+      else:
+        # Fallback: no mask means treat all as valid
+        encoder_attention_mask_img = None
     else:
       # If no image_seq_len is specified, treat all as text
       encoder_hidden_states_img = None
```
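This hunk stops rebuilding the image mask locally and instead slices the first `padded_img_len` columns out of the mask passed down from `embeddings_flax.py`. A minimal sketch of that flow, where the `make_encoder_attention_mask` helper and the concrete lengths (257 real image tokens padded to 384, plus text tokens) are assumptions for illustration only:

```python
# Sketch only: the real mask is built in embeddings_flax.py; this helper
# (make_encoder_attention_mask) and the concrete lengths are assumptions.
import jax.numpy as jnp

def make_encoder_attention_mask(batch, image_seq_len_actual, padded_img_len, text_len):
  """1 for real tokens, 0 for image padding: [1]*257 + [0]*127, then text."""
  img_mask = jnp.zeros((batch, padded_img_len), dtype=jnp.int32)
  img_mask = img_mask.at[:, :image_seq_len_actual].set(1)
  text_mask = jnp.ones((batch, text_len), dtype=jnp.int32)
  return jnp.concatenate([img_mask, text_mask], axis=-1)

encoder_attention_mask = make_encoder_attention_mask(
    batch=2, image_seq_len_actual=257, padded_img_len=384, text_len=512)

# The slice taken in the hunk above: just the image portion of the mask.
encoder_attention_mask_img = encoder_attention_mask[:, :384]
assert int(encoder_attention_mask_img[0].sum()) == 257
```

Passing one mask down from the embedding layer keeps the image/padding split defined in a single place instead of re-deriving it inside the attention module, which is the redundancy this commit removes.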

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -373,6 +373,7 @@ def __call__(
       rotary_emb: jax.Array,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
   ):
     with self.conditional_named_scope("transformer_block"):
       shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
@@ -409,6 +410,7 @@ def __call__(
           encoder_hidden_states=encoder_hidden_states,
           deterministic=deterministic,
           rngs=rngs,
+          encoder_attention_mask=encoder_attention_mask,
       )
     with self.conditional_named_scope("cross_attn_residual"):
       hidden_states = hidden_states + attn_output
@@ -621,7 +623,7 @@ def __call__(
     def scan_fn(carry, block):
       hidden_states_carry, rngs_carry = carry
       hidden_states = block(
-          hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry
+          hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry, encoder_attention_mask
       )
       new_carry = (hidden_states, rngs_carry)
       return new_carry, None
```
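The `scan_fn` change forwards `encoder_attention_mask` to every block as an extra positional argument. Since the mask is identical at every step, it does not belong in the scan carry; it is simply closed over. A self-contained sketch of that pattern with `jax.lax.scan`, where the toy `block` and stacked weights are assumptions, not the real WAN transformer block:

```python
import jax
import jax.numpy as jnp

def block(h, w, mask):
  # Toy stand-in for a transformer block: zero masked positions, then linear.
  return jnp.tanh((h * mask[..., None]) @ w)

def run_blocks(hidden_states, stacked_weights, encoder_attention_mask):
  def scan_fn(carry, w):
    # The mask is closed over, not carried: identical at every scan step.
    h = block(carry, w, encoder_attention_mask)
    return h, None
  final, _ = jax.lax.scan(scan_fn, hidden_states, stacked_weights)
  return final

h = jnp.ones((2, 8, 16))           # (batch, seq, dim)
ws = jnp.stack([jnp.eye(16)] * 4)  # 4 stacked "blocks", scanned over axis 0
mask = jnp.ones((2, 8))            # (batch, seq) validity mask
out = run_blocks(h, ws, mask)      # shape (2, 8, 16)
```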
