@@ -1072,6 +1072,7 @@ def __call__(
         hidden_states: jax.Array,
         encoder_hidden_states: jax.Array = None,
         rotary_emb: Optional[jax.Array] = None,
+        encoder_attention_mask: Optional[jax.Array] = None,
         deterministic: bool = True,
         rngs: nnx.Rngs = None,
     ) -> jax.Array:
@@ -1111,7 +1112,7 @@ def __call__(
             value_proj = checkpoint_name(value_proj, "value_proj")
 
             with jax.named_scope("apply_attention"):
-                attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+                attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask)
 
         else:
             # NEW PATH for I2V CROSS-ATTENTION
@@ -1131,9 +1132,14 @@ def __call__(
                 encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :]
                 encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :]
 
-                encoder_attention_mask_img = jnp.ones((encoder_hidden_states_img.shape[0], padded_img_len), dtype=jnp.int32)
-                if image_seq_len_actual < padded_img_len:
-                    encoder_attention_mask_img = encoder_attention_mask_img.at[:, image_seq_len_actual:].set(0)
+                # Use the passed encoder_attention_mask, which already carries both the image and text masks.
+                if encoder_attention_mask is not None:
+                    encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
+                    encoder_attention_mask_text = encoder_attention_mask[:, padded_img_len:]
+                else:
+                    # Fallback: with no mask passed, treat every token as valid (should not happen once callers supply the mask).
+                    encoder_attention_mask_img = None
+                    encoder_attention_mask_text = None
             else:
                 # If no image_seq_len is specified, treat all tokens as text.
                 encoder_hidden_states_img = None
@@ -1176,7 +1182,7 @@ def __call__(
 
             # Attention - tensors are (B, S, D)
             with self.conditional_named_scope("cross_attn_text_apply"):
-                attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
+                attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)
             with self.conditional_named_scope("cross_attn_img_apply"):
                 # Pass encoder_attention_mask_img for the image cross-attention so padded image tokens are masked out.
                 attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img)
@@ -1189,7 +1195,7 @@ def __call__(
             value_proj_text = checkpoint_name(value_proj_text, "value_proj_text")
 
             with self.conditional_named_scope("cross_attn_text_apply"):
-                attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text)
+                attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text)
 
         attn_output = attn_output.astype(dtype=dtype)
         attn_output = checkpoint_name(attn_output, "attn_output")
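
For context, here is a minimal standalone sketch of how the combined mask threaded through apply_attention above could be built and consumed. It is an illustration under assumptions, not this repository's attention_op implementation: the helper names build_encoder_attention_mask and masked_attention are hypothetical, and the concrete sequence lengths in the usage example are made up. What it demonstrates is the layout the diff relies on: image-mask entries first, text-mask entries second, so that slicing at padded_img_len recovers the two halves.

import jax
import jax.numpy as jnp

def build_encoder_attention_mask(batch, image_seq_len_actual, padded_img_len, text_seq_len):
    # 1 marks a valid image token, 0 the padding that rounds the image
    # context up to padded_img_len.
    img_mask = jnp.ones((batch, padded_img_len), dtype=jnp.int32)
    img_mask = img_mask.at[:, image_seq_len_actual:].set(0)
    # Text tokens are assumed fully valid here; a real caller would pass
    # the tokenizer's padding mask instead.
    text_mask = jnp.ones((batch, text_seq_len), dtype=jnp.int32)
    # Image mask first, text mask second, matching the slicing at
    # padded_img_len in the diff above.
    return jnp.concatenate([img_mask, text_mask], axis=-1)

def masked_attention(query, key, value, attention_mask=None):
    # query: (B, S_q, D); key/value: (B, S_kv, D); attention_mask: (B, S_kv) or None.
    scores = jnp.einsum("bqd,bkd->bqk", query, key) / jnp.sqrt(query.shape[-1])
    if attention_mask is not None:
        # Broadcast the key-side mask over query positions and push masked
        # positions to a very negative logit before the softmax.
        scores = jnp.where(attention_mask[:, None, :] == 1, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bqk,bkd->bqd", weights, value)

# Usage: 100 real image tokens padded to 128, plus 77 text tokens.
mask = build_encoder_attention_mask(batch=2, image_seq_len_actual=100, padded_img_len=128, text_seq_len=77)
mask_img, mask_text = mask[:, :128], mask[:, 128:]  # the same split the diff performs

Passing a key-side (B, S_kv) mask rather than a full (B, S_q, S_kv) matrix is sufficient here because cross-attention only needs to hide padded encoder tokens; every query position masks the same keys.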