We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 13075ad commit 70164a8Copy full SHA for 70164a8
1 file changed
src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py
@@ -276,16 +276,19 @@ def loop_body(step, vals):
276
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
277
278
prompt_embeds_input = prompt_embeds
279
+ image_embeds_input = image_embeds
280
if do_classifier_free_guidance:
281
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
282
+ if image_embeds is not None:
283
+ image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
284
285
286
noise_pred, latents = transformer_forward_pass(
287
graphdef, sharded_state, rest_of_state,
288
latent_model_input, timestep, prompt_embeds_input,
289
do_classifier_free_guidance=do_classifier_free_guidance,
290
guidance_scale=guidance_scale,
- encoder_hidden_states_image=image_embeds,
291
+ encoder_hidden_states_image=image_embeds_input,
292
)
293
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
294
0 commit comments