Skip to content

Commit 8b9f858

Browse files
committed
Trying text_mask 11
1 parent dc8f565 commit 8b9f858

1 file changed

Lines changed: 1 addition & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def run_inference_2_2_i2v(
273273
encoder_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
274274
else:
275275
encoder_attention_mask = None
276+
# WAN 2.2 I2V doesn't use CLIP image embeddings, so image_embeds may be None
276277
if image_embeds is not None:
277278
image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
278279
condition = jnp.concatenate([condition] * 2)
@@ -301,13 +302,6 @@ def low_noise_branch(operands):
301302
)
302303
return noise_pred, latents_out
303304

304-
if do_classifier_free_guidance:
305-
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
306-
# WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder
307-
if image_embeds is not None:
308-
image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
309-
condition = jnp.concatenate([condition] * 2)
310-
311305
for step in range(num_inference_steps):
312306
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
313307
latents_input = latents

0 commit comments

Comments (0)