Skip to content

Commit 649eb38

Browse files
committed
text attn fix
1 parent bfec1e4 commit 649eb38

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1110,7 +1110,10 @@ def __call__(
 1110 1110        value_proj = checkpoint_name(value_proj, "value_proj")
 1111 1111
 1112 1112        with jax.named_scope("apply_attention"):
 1113      -        attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask)
      1113 +        if is_self_attention:
      1114 +          attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
      1115 +        else:
      1116 +          attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask)
 1114 1117
 1115 1118      else:
 1116 1119        # NEW PATH for I2V CROSS-ATTENTION

0 commit comments

Comments
 (0)