Skip to content

Commit 7610f13

Browse files
committed
Changed loop body in wan i2v 2.1
1 parent ec9554a commit 7610f13

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def loop_body(step, vals):
272272
latents_input = jnp.concatenate([latents, latents], axis=0)
273273

274274
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
275+
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
275276
timestep = jnp.broadcast_to(t, latents.shape[0])
276277

277278

@@ -282,6 +283,7 @@ def loop_body(step, vals):
282283
guidance_scale=guidance_scale,
283284
encoder_hidden_states_image=image_embeds,
284285
)
286+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
285287

286288
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
287289
return latents, scheduler_state, rng

0 commit comments

Comments
 (0)