Skip to content

Commit 4957a70

Browse files
committed
passing in attention_mask
1 parent a4c1b6e commit 4957a70

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ def _apply_attention(
614614
attention_kernel,
615615
mask_padding_tokens=mask_padding_tokens,
616616
residual_checkpoint_name=residual_checkpoint_name,
617+
attention_mask=attention_mask,
617618
)
618619
elif "ring" in attention_kernel:
619620
return _tpu_flash_attention(
@@ -628,6 +629,7 @@ def _apply_attention(
628629
dtype,
629630
attention_kernel,
630631
mask_padding_tokens=mask_padding_tokens,
632+
attention_mask=attention_mask,
631633
)
632634
elif attention_kernel == "cudnn_flash_te":
633635
return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
@@ -1218,7 +1220,9 @@ def __call__(
12181220
value_proj = checkpoint_name(value_proj, "value_proj")
12191221

12201222
with jax.named_scope("apply_attention"):
1221-
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
1223+
attn_output = self.attention_op.apply_attention(
1224+
query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask
1225+
)
12221226

12231227
else:
12241228
# NEW PATH for I2V CROSS-ATTENTION
@@ -1462,7 +1466,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non
14621466
key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
14631467
value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)
14641468

1465-
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
1469+
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask)
14661470
context_attn_output = None
14671471

14681472
if encoder_hidden_states is not None:

0 commit comments

Comments (0)