@@ -614,6 +614,7 @@ def _apply_attention(
         attention_kernel,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
+        attention_mask=attention_mask,
     )
   elif "ring" in attention_kernel:
     return _tpu_flash_attention(
@@ -628,6 +629,7 @@ def _apply_attention(
         dtype,
         attention_kernel,
         mask_padding_tokens=mask_padding_tokens,
+        attention_mask=attention_mask,
     )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
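For context, a minimal sketch of how an optional `attention_mask` is typically consumed once it reaches an attention kernel: folded into the logits as a large negative bias before the softmax. This is a hedged illustration, not the module's actual `_tpu_flash_attention` internals; the helper name and shapes are assumptions.

```python
import jax
import jax.numpy as jnp

def _masked_dot_product_attention(query, key, value, attention_mask=None):
  """Hypothetical reference path showing how an optional mask is applied.

  query/key/value: [batch, heads, seq, head_dim].
  attention_mask: boolean, broadcastable to [batch, heads, q_seq, kv_seq]
  (e.g. [batch, 1, 1, kv_seq]), True where attending is allowed.
  """
  scale = query.shape[-1] ** -0.5
  logits = jnp.einsum("bhqd,bhkd->bhqk", query, key) * scale
  if attention_mask is not None:
    # Masked positions get a large negative bias so softmax drives them to ~0.
    logits = jnp.where(attention_mask, logits, jnp.finfo(logits.dtype).min)
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bhqk,bhkd->bhqd", weights, value)
```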
@@ -1218,7 +1220,9 @@ def __call__(
       value_proj = checkpoint_name(value_proj, "value_proj")

     with jax.named_scope("apply_attention"):
-      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+      attn_output = self.attention_op.apply_attention(
+          query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask
+      )

   else:
     # NEW PATH for I2V CROSS-ATTENTION
@@ -1462,7 +1466,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None
       key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
       value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)

-      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask)
       context_attn_output = None

     if encoder_hidden_states is not None:
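At call sites like the ones above, the mask for cross-attention usually derives from the text encoder's padding. A hedged sketch of building a broadcastable key-side mask from padded token ids; the `pad_token_id` default and the `[batch, 1, 1, kv_seq]` layout are assumptions, not this repository's confirmed convention.

```python
import jax.numpy as jnp

def make_encoder_attention_mask(token_ids, pad_token_id=0):
  """Build a [batch, 1, 1, kv_seq] boolean mask from padded token ids.

  True marks real tokens; the extra singleton axes let the mask broadcast
  against [batch, heads, q_seq, kv_seq] attention logits, so every query
  attends only to non-padding encoder positions.
  """
  keep = token_ids != pad_token_id  # [batch, kv_seq]
  return keep[:, None, None, :]

# Hypothetical usage with the call pattern from the diff:
# attn_output = attention_op.apply_attention(
#     query_proj, key_proj, value_proj,
#     attention_mask=make_encoder_attention_mask(encoder_token_ids),
# )
```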