Skip to content

Commit bea08e3

Browse files
committed
timestep error fix wan 2.1
1 parent 665fe66 commit bea08e3

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,20 +269,30 @@ def loop_body(step, vals):
269269

270270
latents_input = latents
271271
if do_classifier_free_guidance:
272+
condition_input = jnp.concatenate([condition, condition], axis=0)
272273
latents_input = jnp.concatenate([latents, latents], axis=0)
274+
else:
275+
condition_input = condition
273276

274-
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
277+
latent_model_input = jnp.concatenate([latents_input, condition_input], axis=-1)
275278
timestep = jnp.broadcast_to(t, latents_input.shape[0])
276279
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
277-
timestep = jnp.broadcast_to(t, latents.shape[0])
280+
281+
prompt_embeds_input = prompt_embeds
282+
image_embeds_input = image_embeds
283+
if do_classifier_free_guidance:
284+
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
285+
if image_embeds is not None:
286+
image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
287+
278288

279289

280290
noise_pred, latents = transformer_forward_pass(
281291
graphdef, sharded_state, rest_of_state,
282-
latent_model_input, timestep, prompt_embeds,
292+
latent_model_input, timestep, prompt_embeds_input,
283293
do_classifier_free_guidance=do_classifier_free_guidance,
284294
guidance_scale=guidance_scale,
285-
encoder_hidden_states_image=image_embeds,
295+
encoder_hidden_states_image=image_embeds_input,
286296
)
287297
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
288298

0 commit comments

Comments
 (0)