Skip to content

Commit 665fe66

Browse files
committed
timestep error fix wan 2.1
1 parent dfdbcc0 commit 665fe66

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def loop_body(step, vals):
272272
latents_input = jnp.concatenate([latents, latents], axis=0)
273273

274274
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
275+
timestep = jnp.broadcast_to(t, latents_input.shape[0])
275276
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
276277
timestep = jnp.broadcast_to(t, latents.shape[0])
277278

0 commit comments

Comments
 (0)