|
25 | 25 | import jax |
26 | 26 | import jax.numpy as jnp |
27 | 27 | from jax.sharding import NamedSharding, PartitionSpec as P |
28 | | -from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler |
| 28 | +from ...schedulers.scheduling_flow_match_flax import FlaxFlowMatchScheduler |
29 | 29 | from ...max_utils import randn_tensor |
30 | 30 |
|
31 | 31 | class WanPipelineI2V_2_1(WanPipeline): |
@@ -65,6 +65,19 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t |
65 | 65 | config=config, |
66 | 66 | ) |
67 | 67 | return pipeline, transformer |
| 68 | + |
| 69 | + @classmethod |
| 70 | + def load_scheduler(cls, config): |
| 71 | + """Overrides the base scheduler loader to use Flow Matching for I2V.""" |
| 72 | + # Wan 2.1 I2V requires Flow Matching with these specific settings: |
| 73 | + # shift=1.0, num_train_timesteps=1000, and usually reverse_sigmas=True (1.0 -> 0.0) |
| 74 | + scheduler, scheduler_state = FlaxFlowMatchScheduler.from_pretrained( |
| 75 | + config.pretrained_model_name_or_path, |
| 76 | + subfolder="scheduler", |
| 77 | + shift=1.0, |
| 78 | + reverse_sigmas=True, |
| 79 | + ) |
| 80 | + return scheduler, scheduler_state |
68 | 81 |
|
69 | 82 | @classmethod |
70 | 83 | def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): |
@@ -275,7 +288,7 @@ def run_inference_2_1_i2v( |
275 | 288 | image_embeds: jnp.array, |
276 | 289 | guidance_scale: float, |
277 | 290 | num_inference_steps: int, |
278 | | - scheduler: FlaxUniPCMultistepScheduler, |
| 291 | + scheduler: FlaxFlowMatchScheduler, |
279 | 292 | scheduler_state, |
280 | 293 | rng: jax.Array, |
281 | 294 | expand_timesteps: bool, |
@@ -323,7 +336,16 @@ def loop_body(step, vals): |
323 | 336 | s=step, |
324 | 337 | std=jnp.std(latents), |
325 | 338 | mean=jnp.mean(latents)) |
326 | | - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() |
| 339 | + |
| 340 | + step_output = scheduler.step( |
| 341 | + state=scheduler_state, |
| 342 | + model_output=noise_pred, |
| 343 | + timestep=t, |
| 344 | + sample=latents, |
| 345 | + return_dict=True |
| 346 | + ) |
| 347 | + |
| 348 | + latents = step_output.prev_sample |
327 | 349 | jax.debug.print("Step {s}: latents_next std={std}, mean={mean}", |
328 | 350 | s=step, |
329 | 351 | std=jnp.std(latents), |
|
0 commit comments