Skip to content

Commit 731b07b

Browse files
committed
support for wan2.1 in run_inference added
1 parent 1be0361 commit 731b07b

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,11 @@ def high_noise_branch(operands):
728728
latents = jnp.concatenate([latents] * 2)
729729
timestep = jnp.broadcast_to(t, latents.shape[0])
730730

731+
if model_name == "wan2.1":
732+
noise_pred, latents = low_noise_branch((latents, timestep, prompt_embeds))
733+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
734+
continue
735+
731736
use_high_noise = jnp.greater_equal(t, boundary)
732737

733738
noise_pred, latents = jax.lax.cond(

0 commit comments

Comments
 (0)