Skip to content

Commit 3e6b745

Browse files
committed
scheduler reverted to original
1 parent 5a71b21 commit 3e6b745

1 file changed

Lines changed: 9 additions & 31 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import jax
2626
import jax.numpy as jnp
2727
from jax.sharding import NamedSharding, PartitionSpec as P
28-
from ...schedulers.scheduling_flow_match_flax import FlaxFlowMatchScheduler
28+
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2929
from ...max_utils import randn_tensor
3030

3131
class WanPipelineI2V_2_1(WanPipeline):
@@ -65,19 +65,6 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6565
config=config,
6666
)
6767
return pipeline, transformer
68-
69-
@classmethod
70-
def load_scheduler(cls, config):
71-
"""Overrides the base scheduler loader to use Flow Matching for I2V."""
72-
# Wan 2.1 I2V requires Flow Matching with these specific settings:
73-
# shift=1.0, num_train_timesteps=1000, and usually reverse_sigmas=True (1.0 -> 0.0)
74-
scheduler, scheduler_state = FlaxFlowMatchScheduler.from_pretrained(
75-
config.pretrained_model_name_or_path,
76-
subfolder="scheduler",
77-
shift=1.0,
78-
reverse_sigmas=True,
79-
)
80-
return scheduler, scheduler_state
8168

8269
@classmethod
8370
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
@@ -230,12 +217,12 @@ def __call__(
230217
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
231218
)
232219

233-
# if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None:
234-
# t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
235-
# dummy_noise = jnp.zeros_like(latents)
236-
# # This call initializes the internal state arrays
237-
# step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
238-
# scheduler_state = step_output.state
220+
if self.scheduler_state.last_sample is None or self.scheduler_state.step_index is None:
221+
t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
222+
dummy_noise = jnp.zeros_like(latents)
223+
# This call initializes the internal state arrays
224+
step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
225+
scheduler_state = step_output.state
239226
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
240227
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
241228
latents = jax.device_put(latents, data_sharding)
@@ -288,7 +275,7 @@ def run_inference_2_1_i2v(
288275
image_embeds: jnp.array,
289276
guidance_scale: float,
290277
num_inference_steps: int,
291-
scheduler: FlaxFlowMatchScheduler,
278+
scheduler: FlaxUniPCMultistepScheduler,
292279
scheduler_state,
293280
rng: jax.Array,
294281
expand_timesteps: bool,
@@ -336,16 +323,7 @@ def loop_body(step, vals):
336323
s=step,
337324
std=jnp.std(latents),
338325
mean=jnp.mean(latents))
339-
340-
step_output = scheduler.step(
341-
state=scheduler_state,
342-
model_output=noise_pred,
343-
timestep=t,
344-
sample=latents,
345-
return_dict=True
346-
)
347-
348-
latents = step_output.prev_sample
326+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
349327
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
350328
s=step,
351329
std=jnp.std(latents),

0 commit comments

Comments
 (0)