Skip to content

Commit 70164a8

Browse files
committed
image embeds duplicated in 2.1
1 parent 13075ad commit 70164a8

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,16 +276,19 @@ def loop_body(step, vals):
276276
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
277277

278278
prompt_embeds_input = prompt_embeds
279+
image_embeds_input = image_embeds
279280
if do_classifier_free_guidance:
280281
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
282+
if image_embeds is not None:
283+
image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
281284

282285

283286
noise_pred, latents = transformer_forward_pass(
284287
graphdef, sharded_state, rest_of_state,
285288
latent_model_input, timestep, prompt_embeds_input,
286289
do_classifier_free_guidance=do_classifier_free_guidance,
287290
guidance_scale=guidance_scale,
288-
encoder_hidden_states_image=image_embeds,
291+
encoder_hidden_states_image=image_embeds_input,
289292
)
290293
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
291294

0 commit comments

Comments
 (0)