Skip to content

Commit 0f88829

Browse files
committed
replaced boundary_timestep with ratio in wan2.2 t2v
1 parent 7171454 commit 0f88829

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ guidance_scale_high: 4.0
300300
# The timestep threshold. If `t` is at or above this value,
301301
# the `high_noise_model` is considered as the required model.
302302
# timestep to switch between low noise and high noise transformer
303-
boundary_timestep: 875
303+
boundary_ratio: 0.875
304304

305305
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
306306
guidance_rescale: 0.0

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
134134
num_inference_steps=config.num_inference_steps,
135135
guidance_scale_low=config.guidance_scale_low,
136136
guidance_scale_high=config.guidance_scale_high,
137-
boundary=config.boundary_timestep,
138137
)
139138
else:
140139
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
wan_pipeline_2_2.py
12
# Copyright 2025 Google LLC
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -38,6 +39,7 @@ def __init__(
3839
super().__init__(config=config, **kwargs)
3940
self.low_noise_transformer = low_noise_transformer
4041
self.high_noise_transformer = high_noise_transformer
42+
self.boundary_ratio = config.boundary_ratio
4143

4244
@classmethod
4345
def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True):
@@ -103,7 +105,6 @@ def __call__(
103105
num_inference_steps: int = 50,
104106
guidance_scale_low: float = 3.0,
105107
guidance_scale_high: float = 4.0,
106-
boundary: int = 875,
107108
num_videos_per_prompt: Optional[int] = 1,
108109
max_sequence_length: int = 512,
109110
latents: jax.Array = None,
@@ -129,11 +130,13 @@ def __call__(
129130
low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...)
130131
high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...)
131132

133+
boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps
134+
132135
p_run_inference = partial(
133136
run_inference_2_2,
134137
guidance_scale_low=guidance_scale_low,
135138
guidance_scale_high=guidance_scale_high,
136-
boundary=boundary,
139+
boundary=boundary_timestep,
137140
num_inference_steps=num_inference_steps,
138141
scheduler=self.scheduler,
139142
scheduler_state=scheduler_state,

0 commit comments

Comments
 (0)