
Commit 8393ec4

text attn mask fix
1 parent 3538f1a · commit 8393ec4

1 file changed: src/maxdiffusion/models/attention_flax.py (12 additions, 6 deletions)

This change threads a caller-supplied `encoder_attention_mask` through `__call__` and passes it to the text cross-attention calls. Previously, only an image padding mask was applied, and it was rebuilt inside the layer from `image_seq_len_actual`; padded text tokens were never masked.
```diff
@@ -1072,6 +1072,7 @@ def __call__(
       hidden_states: jax.Array,
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
```
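
For context, a minimal sketch of how a caller might build the combined mask this signature now accepts. The 0/1 convention, shapes, and variable names are illustrative assumptions, not taken from the commit; only the ordering (image tokens first, then text) is implied by the slicing further down.

```python
import jax.numpy as jnp

# Illustrative shapes only; none of these numbers come from the commit.
batch, padded_img_len, text_len = 2, 257, 512
image_seq_len_actual = 200  # real image tokens; the rest is padding

# 1 = attend, 0 = ignore. Image tokens come first, then text tokens,
# matching the slicing in the I2V cross-attention path below.
img_mask = jnp.zeros((batch, padded_img_len), dtype=jnp.int32)
img_mask = img_mask.at[:, :image_seq_len_actual].set(1)
text_mask = jnp.ones((batch, text_len), dtype=jnp.int32)

encoder_attention_mask = jnp.concatenate([img_mask, text_mask], axis=1)
print(encoder_attention_mask.shape)  # (2, 769)
```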
```diff
@@ -1111,7 +1112,7 @@ def __call__(
       value_proj = checkpoint_name(value_proj, "value_proj")

       with jax.named_scope("apply_attention"):
-        attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+        attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask)

     else:
       # NEW PATH for I2V CROSS-ATTENTION
```
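
The diff does not show what `apply_attention` does with `attention_mask`. A common pattern, sketched here as an assumption rather than as maxdiffusion's actual implementation, is to turn the 0/1 key mask into an additive bias on the pre-softmax attention scores:

```python
import jax.numpy as jnp

def apply_padding_mask(scores, attention_mask):
  # Hypothetical helper, not from this codebase: broadcast a (B, S_kv)
  # 0/1 key mask over (B, H, S_q, S_kv) scores and push masked positions
  # toward -inf so the softmax assigns them ~0 weight.
  if attention_mask is None:
    return scores
  bias = jnp.where(attention_mask[:, None, None, :] == 1, 0.0, -1e9)
  return scores + bias

scores = jnp.zeros((2, 8, 16, 769))            # (B, H, S_q, S_kv), dummy values
mask = jnp.ones((2, 769), dtype=jnp.int32)     # all-ones mask: a no-op
print(apply_padding_mask(scores, mask).shape)  # (2, 8, 16, 769)
```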
```diff
@@ -1131,9 +1132,14 @@ def __call__(
         encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
         encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]

-        encoder_attention_mask_img = jnp.ones((encoder_hidden_states_img.shape[0], padded_img_len), dtype=jnp.int32)
-        if image_seq_len_actual < padded_img_len:
-          encoder_attention_mask_img = encoder_attention_mask_img.at[:, image_seq_len_actual:].set(0)
+        # Use the passed encoder_attention_mask, which already contains both image and text masks
+        if encoder_attention_mask is not None:
+          encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
+          encoder_attention_mask_text = encoder_attention_mask[:, padded_img_len:]
+        else:
+          # Fallback: if no mask passed, treat all as valid (shouldn't happen with our fix)
+          encoder_attention_mask_img = None
+          encoder_attention_mask_text = None
       else:
         # If no image_seq_len is specified, treat all as text
         encoder_hidden_states_img = None
```
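
A small before/after sketch of what this hunk changes, with made-up shapes and a 0/1 mask convention: the image half of the sliced mask matches what the old code rebuilt internally from `image_seq_len_actual`, while the new path additionally yields a text mask, which the old path never produced.

```python
import jax.numpy as jnp

batch, padded_img_len, text_len = 2, 257, 512
image_seq_len_actual = 200

# Old behavior: image mask rebuilt inside the layer; text never masked.
old_img_mask = jnp.ones((batch, padded_img_len), dtype=jnp.int32)
old_img_mask = old_img_mask.at[:, image_seq_len_actual:].set(0)

# New behavior: slice the caller-supplied combined mask.
combined = jnp.concatenate(
    [old_img_mask, jnp.ones((batch, text_len), jnp.int32)], axis=1
)
new_img_mask = combined[:, :padded_img_len]
new_text_mask = combined[:, padded_img_len:]

assert (new_img_mask == old_img_mask).all()  # image half is unchanged
print(new_text_mask.shape)  # (2, 512): text padding can now be masked too
```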
```diff
@@ -1176,7 +1182,7 @@ def __call__(

         # Attention - tensors are (B, S, D)
         with self.conditional_named_scope("cross_attn_text_apply"):
-          attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
+          attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)
         with self.conditional_named_scope("cross_attn_img_apply"):
           # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens
           attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img)
```
```diff
@@ -1189,7 +1195,7 @@ def __call__(
         value_proj_text = checkpoint_name(value_proj_text, "value_proj_text")

         with self.conditional_named_scope("cross_attn_text_apply"):
-          attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
+          attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
```
