Skip to content

Commit e76dd1e

Browse files
committed
Fix timestep error in the Wan 2.2 I2V pipeline
1 parent bea08e3 commit e76dd1e

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -265,16 +265,24 @@ def loop_body(step, vals):
265265
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
266266

267267
latents_input = latents
268+
condition_input = condition
269+
prompt_embeds_input = prompt_embeds
270+
image_embeds_input = image_embeds
268271
if do_classifier_free_guidance:
269272
latents_input = jnp.concatenate([latents, latents], axis=0)
273+
condition_input = jnp.concatenate([condition, condition], axis=0)
274+
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
275+
if image_embeds is not None:
276+
image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
277+
270278

271279
if expand_timesteps:
272-
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents_input
280+
latent_model_input = (1 - first_frame_mask) * condition_input + first_frame_mask * latents_input
273281
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
274282
timestep = jnp.expand_dims(temp_ts, axis=0)
275283
timestep = jnp.broadcast_to(timestep, (latents_input.shape[0], temp_ts.shape[0]))
276284
else:
277-
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
285+
latent_model_input = jnp.concatenate([latents_input, condition_input], axis=-1)
278286
timestep = jnp.broadcast_to(t, latents_input.shape[0])
279287

280288

@@ -308,7 +316,7 @@ def low_noise_branch(operands):
308316
use_high_noise,
309317
high_noise_branch,
310318
low_noise_branch,
311-
(latent_model_input, timestep, prompt_embeds, image_embeds)
319+
(latent_model_input, timestep, prompt_embeds_input, image_embeds_input)
312320
)
313321

314322
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

0 commit comments

Comments
 (0)