@@ -125,6 +125,8 @@ def __call__(
125125 latents = latents ,
126126 prompt_embeds = prompt_embeds ,
127127 negative_prompt_embeds = negative_prompt_embeds ,
128+ text_attention_mask = text_attention_mask ,
129+ negative_text_attention_mask = negative_text_attention_mask ,
128130 )
129131 latents = self ._denormalize_latents (latents )
130132 return self ._decode_latents_to_video (latents )
@@ -140,10 +142,22 @@ def run_inference_2_1(
140142 num_inference_steps : int ,
141143 scheduler : FlaxUniPCMultistepScheduler ,
142144 scheduler_state ,
145+ text_attention_mask : Optional [jnp .array ] = None ,
146+ negative_text_attention_mask : Optional [jnp .array ] = None ,
143147):
144148 do_classifier_free_guidance = guidance_scale > 1.0
145149 if do_classifier_free_guidance :
146150 prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
151+ # Concatenate masks for CFG: [positive_mask | negative_mask]
152+ if text_attention_mask is not None and negative_text_attention_mask is not None :
153+ encoder_attention_mask = jnp .concatenate ([text_attention_mask , negative_text_attention_mask ], axis = 0 )
154+ print (f"[DEBUG run_inference_2_1] Concatenated mask shape: { encoder_attention_mask .shape } " )
155+ print (f"[DEBUG run_inference_2_1] Mask sums - pos: { text_attention_mask .sum ()} , neg: { negative_text_attention_mask .sum ()} , combined: { encoder_attention_mask .sum ()} " )
156+ else :
157+ encoder_attention_mask = None
158+ else :
159+ encoder_attention_mask = text_attention_mask
160+
147161 for step in range (num_inference_steps ):
148162 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
149163 if do_classifier_free_guidance :
@@ -159,7 +173,11 @@ def run_inference_2_1(
159173 prompt_embeds ,
160174 do_classifier_free_guidance = do_classifier_free_guidance ,
161175 guidance_scale = guidance_scale ,
176+ encoder_attention_mask = encoder_attention_mask ,  # Pass mask
162177 )
178+
179+ if step == 0 :
180+ print (f"[DEBUG run_inference_2_1] Step 0 - passed encoder_attention_mask shape: { encoder_attention_mask .shape if encoder_attention_mask is not None else None } " )
163181
164182 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
165183 return latents
0 commit comments