File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -821,14 +821,16 @@ def __call__(
821821 if audio_encoder_attention_mask is not None : print_shape ("audio_encoder_attention_mask input" , audio_encoder_attention_mask )
822822
823823
824- if encoder_attention_mask is not None and encoder_attention_mask .ndim == 2 :
825- encoder_attention_mask = (1 - encoder_attention_mask .astype (self .dtype )) * - 10000.0
826- encoder_attention_mask = jnp .expand_dims (encoder_attention_mask , axis = 1 )
824+ if self .attention_kernel == "dot_product" :
825+ if encoder_attention_mask is not None and encoder_attention_mask .ndim == 2 :
826+ encoder_attention_mask = (1 - encoder_attention_mask .astype (self .dtype )) * - 10000.0
827+ encoder_attention_mask = jnp .expand_dims (encoder_attention_mask , axis = 1 )
828+
829+ if audio_encoder_attention_mask is not None and audio_encoder_attention_mask .ndim == 2 :
830+ audio_encoder_attention_mask = (1 - audio_encoder_attention_mask .astype (self .dtype )) * - 10000.0
831+ audio_encoder_attention_mask = jnp .expand_dims (audio_encoder_attention_mask , axis = 1 )
832+
827833 if encoder_attention_mask is not None : print_shape ("encoder_attention_mask bias" , encoder_attention_mask )
828-
829- if audio_encoder_attention_mask is not None and audio_encoder_attention_mask .ndim == 2 :
830- audio_encoder_attention_mask = (1 - audio_encoder_attention_mask .astype (self .dtype )) * - 10000.0
831- audio_encoder_attention_mask = jnp .expand_dims (audio_encoder_attention_mask , axis = 1 )
832834 if audio_encoder_attention_mask is not None : print_shape ("audio_encoder_attention_mask bias" , audio_encoder_attention_mask )
833835
834836 batch_size = hidden_states .shape [0 ]
You can’t perform that action at this time.
0 commit comments