|
25 | 25 | import jax |
26 | 26 | import jax.numpy as jnp |
27 | 27 | from jax.sharding import NamedSharding, PartitionSpec as P |
28 | | -from ...schedulers.scheduling_flow_match_flax import FlaxFlowMatchScheduler |
| 28 | +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler |
29 | 29 | from ...max_utils import randn_tensor |
30 | 30 |
|
31 | 31 | class WanPipelineI2V_2_1(WanPipeline): |
@@ -65,19 +65,6 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t |
65 | 65 | config=config, |
66 | 66 | ) |
67 | 67 | return pipeline, transformer |
68 | | - |
69 | | - @classmethod |
70 | | - def load_scheduler(cls, config): |
71 | | - """Overrides the base scheduler loader to use Flow Matching for I2V.""" |
72 | | - # Wan 2.1 I2V requires Flow Matching with these specific settings: |
73 | | - # shift=1.0, num_train_timesteps=1000, and usually reverse_sigmas=True (1.0 -> 0.0) |
74 | | - scheduler, scheduler_state = FlaxFlowMatchScheduler.from_pretrained( |
75 | | - config.pretrained_model_name_or_path, |
76 | | - subfolder="scheduler", |
77 | | - shift=1.0, |
78 | | - reverse_sigmas=True, |
79 | | - ) |
80 | | - return scheduler, scheduler_state |
81 | 68 |
|
82 | 69 | @classmethod |
83 | 70 | def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): |
@@ -230,12 +217,12 @@ def __call__( |
230 | 217 | self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape |
231 | 218 | ) |
232 | 219 |
|
233 | | - # if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None: |
234 | | - # t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0] |
235 | | - # dummy_noise = jnp.zeros_like(latents) |
236 | | - # # This call initializes the internal state arrays |
237 | | - # step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents) |
238 | | - # scheduler_state = step_output.state |
| 220 | + if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None: |
| 221 | + t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0] |
| 222 | + dummy_noise = jnp.zeros_like(latents) |
| 223 | + # This call initializes the internal state arrays |
| 224 | + step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents) |
| 225 | + scheduler_state = step_output.state |
239 | 226 | graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) |
240 | 227 | data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding)) |
241 | 228 | latents = jax.device_put(latents, data_sharding) |
@@ -288,7 +275,7 @@ def run_inference_2_1_i2v( |
288 | 275 | image_embeds: jnp.array, |
289 | 276 | guidance_scale: float, |
290 | 277 | num_inference_steps: int, |
291 | | - scheduler: FlaxFlowMatchScheduler, |
| 278 | + scheduler: FlaxUniPCMultistepScheduler, |
292 | 279 | scheduler_state, |
293 | 280 | rng: jax.Array, |
294 | 281 | expand_timesteps: bool, |
@@ -336,16 +323,7 @@ def loop_body(step, vals): |
336 | 323 | s=step, |
337 | 324 | std=jnp.std(latents), |
338 | 325 | mean=jnp.mean(latents)) |
339 | | - |
340 | | - step_output = scheduler.step( |
341 | | - state=scheduler_state, |
342 | | - model_output=noise_pred, |
343 | | - timestep=t, |
344 | | - sample=latents, |
345 | | - return_dict=True |
346 | | - ) |
347 | | - |
348 | | - latents = step_output.prev_sample |
| 326 | + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() |
349 | 327 | jax.debug.print("Step {s}: latents_next std={std}, mean={mean}", |
350 | 328 | s=step, |
351 | 329 | std=jnp.std(latents), |
|
0 commit comments