Skip to content

Commit 72373a0

Browse files
committed
scheduler changed
1 parent 390e113 commit 72373a0

1 file changed

Lines changed: 25 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import jax
2626
import jax.numpy as jnp
2727
from jax.sharding import NamedSharding, PartitionSpec as P
28-
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
28+
from ...schedulers.scheduling_flow_match_flax import FlaxFlowMatchScheduler
2929
from ...max_utils import randn_tensor
3030

3131
class WanPipelineI2V_2_1(WanPipeline):
@@ -65,6 +65,19 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6565
config=config,
6666
)
6767
return pipeline, transformer
68+
69+
@classmethod
70+
def load_scheduler(cls, config):
71+
"""Overrides the base scheduler loader to use Flow Matching for I2V."""
72+
# Wan 2.1 I2V requires Flow Matching with these specific settings:
73+
# shift=1.0, num_train_timesteps=1000, and usually reverse_sigmas=True (1.0 -> 0.0)
74+
scheduler, scheduler_state = FlaxFlowMatchScheduler.from_pretrained(
75+
config.pretrained_model_name_or_path,
76+
subfolder="scheduler",
77+
shift=1.0,
78+
reverse_sigmas=True,
79+
)
80+
return scheduler, scheduler_state
6881

6982
@classmethod
7083
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
@@ -275,7 +288,7 @@ def run_inference_2_1_i2v(
275288
image_embeds: jnp.array,
276289
guidance_scale: float,
277290
num_inference_steps: int,
278-
scheduler: FlaxUniPCMultistepScheduler,
291+
scheduler: FlaxFlowMatchScheduler,
279292
scheduler_state,
280293
rng: jax.Array,
281294
expand_timesteps: bool,
@@ -323,7 +336,16 @@ def loop_body(step, vals):
323336
s=step,
324337
std=jnp.std(latents),
325338
mean=jnp.mean(latents))
326-
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
339+
340+
step_output = scheduler.step(
341+
state=scheduler_state,
342+
model_output=noise_pred,
343+
timestep=t,
344+
sample=latents,
345+
return_dict=True
346+
)
347+
348+
latents = step_output.prev_sample
327349
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
328350
s=step,
329351
std=jnp.std(latents),

0 commit comments

Comments
 (0)