Skip to content

Commit 00077c1

Browse files
committed
Trying text_mask 2
1 parent acc7452 commit 00077c1

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,22 @@ def transformer_forward_pass(
769769
do_classifier_free_guidance,
770770
guidance_scale,
771771
encoder_hidden_states_image=None,
772+
encoder_attention_mask=None,
772773
):
773774
wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
774-
noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image)
775+
776+
# DEBUG: Print mask info (only compiles once due to jit)
777+
# jax.debug.print("[DEBUG transformer_forward_pass] encoder_attention_mask shape: {}",
778+
# encoder_attention_mask.shape if encoder_attention_mask is not None else "None")
779+
780+
# For now, DON'T pass the mask - just accept it
781+
noise_pred = wan_transformer(
782+
hidden_states=latents,
783+
timestep=timestep,
784+
encoder_hidden_states=prompt_embeds,
785+
encoder_hidden_states_image=encoder_hidden_states_image
786+
# encoder_attention_mask=encoder_attention_mask # TODO: Add this next
787+
)
775788
if do_classifier_free_guidance:
776789
bsz = latents.shape[0] // 2
777790
noise_cond = noise_pred[:bsz] # First half = conditional

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __call__(
125125
latents=latents,
126126
prompt_embeds=prompt_embeds,
127127
negative_prompt_embeds=negative_prompt_embeds,
128+
text_attention_mask=text_attention_mask,
129+
negative_text_attention_mask=negative_text_attention_mask,
128130
)
129131
latents = self._denormalize_latents(latents)
130132
return self._decode_latents_to_video(latents)
@@ -140,10 +142,22 @@ def run_inference_2_1(
140142
num_inference_steps: int,
141143
scheduler: FlaxUniPCMultistepScheduler,
142144
scheduler_state,
145+
text_attention_mask: Optional[jnp.array] = None,
146+
negative_text_attention_mask: Optional[jnp.array] = None,
143147
):
144148
do_classifier_free_guidance = guidance_scale > 1.0
145149
if do_classifier_free_guidance:
146150
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
151+
# Concatenate masks for CFG: [positive_mask | negative_mask]
152+
if text_attention_mask is not None and negative_text_attention_mask is not None:
153+
encoder_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
154+
print(f"[DEBUG run_inference_2_1] Concatenated mask shape: {encoder_attention_mask.shape}")
155+
print(f"[DEBUG run_inference_2_1] Mask sums - pos: {text_attention_mask.sum()}, neg: {negative_text_attention_mask.sum()}, combined: {encoder_attention_mask.sum()}")
156+
else:
157+
encoder_attention_mask = None
158+
else:
159+
encoder_attention_mask = text_attention_mask
160+
147161
for step in range(num_inference_steps):
148162
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
149163
if do_classifier_free_guidance:
@@ -159,7 +173,11 @@ def run_inference_2_1(
159173
prompt_embeds,
160174
do_classifier_free_guidance=do_classifier_free_guidance,
161175
guidance_scale=guidance_scale,
176+
encoder_attention_mask=encoder_attention_mask if step == 0 else encoder_attention_mask, # Pass mask
162177
)
178+
179+
if step == 0:
180+
print(f"[DEBUG run_inference_2_1] Step 0 - passed encoder_attention_mask shape: {encoder_attention_mask.shape if encoder_attention_mask is not None else None}")
163181

164182
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
165183
return latents

0 commit comments

Comments (0)