Skip to content

Commit 5a448f3

Browse files
committed
Reverting encoder_mask to None for T2V models
1 parent 6004e92 commit 5a448f3

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,11 @@ def __call__(
11931193

11941194
is_i2v_cross_attention = self.added_kv_proj_dim is not None and not is_self_attention
11951195

1196+
# For T2V self-attention and cross-attention, we skip passing the mask
1197+
# to avoid overhead, as it should be all 1s for unpadded sequences.
1198+
if not is_i2v_cross_attention:
1199+
encoder_attention_mask = None
1200+
11961201
if not is_i2v_cross_attention:
11971202
with jax.named_scope("query_proj"):
11981203
query_proj = self.query(hidden_states)

0 commit comments

Comments
 (0)