Skip to content

Commit 6c49ce7

Browse files
committed
duplication corrected in wan2.2 i2v
1 parent 20bfaa2 commit 6c49ce7

1 file changed

Lines changed: 3 additions & 11 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,24 +265,16 @@ def loop_body(step, vals):
265265
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
266266

267267
latents_input = latents
268-
condition_input = condition
269-
prompt_embeds_input = prompt_embeds
270-
image_embeds_input = image_embeds
271268
if do_classifier_free_guidance:
272269
latents_input = jnp.concatenate([latents, latents], axis=0)
273-
condition_input = jnp.concatenate([condition, condition], axis=0)
274-
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
275-
if image_embeds is not None:
276-
image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
277-
278270

279271
if expand_timesteps:
280-
latent_model_input = (1 - first_frame_mask) * condition_input + first_frame_mask * latents_input
272+
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents_input
281273
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
282274
timestep = jnp.expand_dims(temp_ts, axis=0)
283275
timestep = jnp.broadcast_to(timestep, (latents_input.shape[0], temp_ts.shape[0]))
284276
else:
285-
latent_model_input = jnp.concatenate([latents_input, condition_input], axis=-1)
277+
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
286278
timestep = jnp.broadcast_to(t, latents_input.shape[0])
287279

288280

@@ -316,7 +308,7 @@ def low_noise_branch(operands):
316308
use_high_noise,
317309
high_noise_branch,
318310
low_noise_branch,
319-
(latent_model_input, timestep, prompt_embeds_input, image_embeds_input)
311+
(latent_model_input, timestep, prompt_embeds, image_embeds)
320312
)
321313

322314
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

0 commit comments

Comments
 (0)