Skip to content

Commit 20bfaa2

Browse files
committed
duplication corrected in wan2.1 i2v
1 parent e76dd1e commit 20bfaa2

1 file changed

Lines changed: 2 additions & 9 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -269,30 +269,23 @@ def loop_body(step, vals):
269269

270270
latents_input = latents
271271
if do_classifier_free_guidance:
272-
condition_input = jnp.concatenate([condition, condition], axis=0)
273272
latents_input = jnp.concatenate([latents, latents], axis=0)
274-
else:
275-
condition_input = condition
276273

277-
latent_model_input = jnp.concatenate([latents_input, condition_input], axis=-1)
274+
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
278275
timestep = jnp.broadcast_to(t, latents_input.shape[0])
279276
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
280277

281278
prompt_embeds_input = prompt_embeds
282-
image_embeds_input = image_embeds
283279
if do_classifier_free_guidance:
284280
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
285-
if image_embeds is not None:
286-
image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
287-
288281

289282

290283
noise_pred, latents = transformer_forward_pass(
291284
graphdef, sharded_state, rest_of_state,
292285
latent_model_input, timestep, prompt_embeds_input,
293286
do_classifier_free_guidance=do_classifier_free_guidance,
294287
guidance_scale=guidance_scale,
295-
encoder_hidden_states_image=image_embeds_input,
288+
encoder_hidden_states_image=image_embeds,
296289
)
297290
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
298291

0 commit comments

Comments
 (0)