
Commit 231b379 (parent 1e92718)

removed redundant img attn mask

2 files changed: 11 additions & 4 deletions


src/maxdiffusion/models/attention_flax.py (8 additions, 3 deletions)
```diff
@@ -1038,6 +1038,7 @@ def __call__(
     hidden_states: jax.Array,
     encoder_hidden_states: jax.Array = None,
     rotary_emb: Optional[jax.Array] = None,
+    encoder_attention_mask: Optional[jax.Array] = None,
     deterministic: bool = True,
     rngs: nnx.Rngs = None,
   ) -> jax.Array:
@@ -1097,9 +1098,13 @@ def __call__(
       encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
       encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]

-      encoder_attention_mask_img = jnp.ones((encoder_hidden_states_img.shape[0], padded_img_len), dtype=jnp.int32)
-      if image_seq_len_actual < padded_img_len:
-        encoder_attention_mask_img = encoder_attention_mask_img.at[:, image_seq_len_actual:].set(0)
+      # Use the passed encoder_attention_mask (created in embeddings_flax.py).
+      # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384.
+      if encoder_attention_mask is not None:
+        encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
+      else:
+        # Fallback: no mask means treat all as valid.
+        encoder_attention_mask_img = None
     else:
       # If no image_seq_len is specified, treat all as text
       encoder_hidden_states_img = None
```
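To make the handoff concrete, here is a small runnable sketch of the slicing the new code performs, using the 257-real-tokens-padded-to-384 shapes quoted in the diff comment. The upstream mask construction, the text length, and the additive-bias attention helper are illustrative assumptions, not the actual embeddings_flax.py or attention_flax.py implementations:

```python
import jax
import jax.numpy as jnp

batch, image_seq_len_actual, padded_img_len = 2, 257, 384
text_len = 512  # illustrative; the real text length comes from the encoder

# Stand-in for the mask built upstream in embeddings_flax.py: image tokens
# first ([1]*257 + [0]*127), then text tokens (all valid here for simplicity).
encoder_attention_mask = jnp.concatenate(
    [
        jnp.ones((batch, image_seq_len_actual), dtype=jnp.int32),
        jnp.zeros((batch, padded_img_len - image_seq_len_actual), dtype=jnp.int32),
        jnp.ones((batch, text_len), dtype=jnp.int32),
    ],
    axis=1,
)

# What the attention layer now does: slice the image portion out of the
# passed-in mask instead of rebuilding it from scratch.
encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
assert int(encoder_attention_mask_img[0].sum()) == image_seq_len_actual  # 257 real tokens

# One common way such a 1/0 key mask is consumed: as an additive bias so
# padded image positions get ~zero weight after softmax (illustrative, not
# necessarily how attention_flax.py applies it internally).
def masked_attn_weights(q, k, key_mask):
  scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
  bias = jnp.where(key_mask[:, None, :] == 1, 0.0, -1e9)
  return jax.nn.softmax(scores + bias, axis=-1)

q = jnp.ones((batch, 4, 16))
k = jnp.ones((batch, padded_img_len, 16))
weights = masked_attn_weights(q, k, encoder_attention_mask_img)
assert float(weights[0, 0, image_seq_len_actual:].sum()) < 1e-6  # padding ignored
```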

src/maxdiffusion/models/wan/transformers/transformer_wan.py (3 additions, 1 deletion)
```diff
@@ -373,6 +373,7 @@ def __call__(
     rotary_emb: jax.Array,
     deterministic: bool = True,
     rngs: nnx.Rngs = None,
+    encoder_attention_mask: Optional[jax.Array] = None,
   ):
     with self.conditional_named_scope("transformer_block"):
       shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
@@ -409,6 +410,7 @@ def __call__(
         encoder_hidden_states=encoder_hidden_states,
         deterministic=deterministic,
         rngs=rngs,
+        encoder_attention_mask=encoder_attention_mask,
       )
     with self.conditional_named_scope("cross_attn_residual"):
       hidden_states = hidden_states + attn_output
@@ -621,7 +623,7 @@ def __call__(
     def scan_fn(carry, block):
       hidden_states_carry, rngs_carry = carry
       hidden_states = block(
-          hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry
+          hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry, encoder_attention_mask
       )
       new_carry = (hidden_states, rngs_carry)
       return new_carry, None
```
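The scan change is purely about positional argument order: the block is invoked positionally inside scan_fn, so the new encoder_attention_mask parameter must sit after rngs in the block's signature and be appended last at the call site. The sketch below shows that threading in plain Python; the block name, its toy body, and the dummy shapes are assumptions, not the real transformer block or the nnx scan machinery:

```python
from typing import Optional

import jax
import jax.numpy as jnp

# Illustrative stand-in for the transformer block's updated signature:
# encoder_attention_mask comes after rngs, matching the diff.
def block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb,
          deterministic: bool = True, rngs=None,
          encoder_attention_mask: Optional[jax.Array] = None):
  del timestep_proj, rotary_emb, deterministic, rngs  # unused in this sketch
  if encoder_attention_mask is not None:
    # Zero out padded encoder tokens before a toy cross-attention update.
    encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[..., None]
  return hidden_states + encoder_hidden_states.mean(axis=1, keepdims=True)

# Dummy inputs so the sketch runs end to end.
hidden_states = jnp.zeros((2, 4, 8))
encoder_hidden_states = jnp.ones((2, 6, 8))
timestep_proj, rotary_emb, deterministic = None, None, True
encoder_attention_mask = jnp.ones((2, 6), dtype=jnp.int32)

def scan_fn(carry, block_):
  hidden_states_carry, rngs_carry = carry
  # Positional call mirroring the diff: the mask is appended as the last argument.
  hidden_states_out = block_(
      hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb,
      deterministic, rngs_carry, encoder_attention_mask,
  )
  return (hidden_states_out, rngs_carry), None

(hidden_states, _), _ = scan_fn((hidden_states, None), block)
print(hidden_states.shape)  # (2, 4, 8)
```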
