Skip to content

Commit 7a6bfec

Browse files
committed
Fix for TPU flash attention
1 parent abc1575 commit 7a6bfec

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -821,14 +821,16 @@ def __call__(
821821
if audio_encoder_attention_mask is not None: print_shape("audio_encoder_attention_mask input", audio_encoder_attention_mask)
822822

823823

824-
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
825-
encoder_attention_mask = (1 - encoder_attention_mask.astype(self.dtype)) * -10000.0
826-
encoder_attention_mask = jnp.expand_dims(encoder_attention_mask, axis=1)
824+
if self.attention_kernel == "dot_product":
825+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
826+
encoder_attention_mask = (1 - encoder_attention_mask.astype(self.dtype)) * -10000.0
827+
encoder_attention_mask = jnp.expand_dims(encoder_attention_mask, axis=1)
828+
829+
if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
830+
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.astype(self.dtype)) * -10000.0
831+
audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1)
832+
827833
if encoder_attention_mask is not None: print_shape("encoder_attention_mask bias", encoder_attention_mask)
828-
829-
if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
830-
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.astype(self.dtype)) * -10000.0
831-
audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1)
832834
if audio_encoder_attention_mask is not None: print_shape("audio_encoder_attention_mask bias", audio_encoder_attention_mask)
833835

834836
batch_size = hidden_states.shape[0]

0 commit comments

Comments (0)