1+ wan_pipeline_2_2 .py
12# Copyright 2025 Google LLC
23#
34# Licensed under the Apache License, Version 2.0 (the "License");
@@ -38,6 +39,7 @@ def __init__(
3839 super ().__init__ (config = config , ** kwargs )
3940 self .low_noise_transformer = low_noise_transformer
4041 self .high_noise_transformer = high_noise_transformer
42+ self .boundary_ratio = config .boundary_ratio
4143
4244 @classmethod
4345 def _load_and_init (cls , config , restored_checkpoint = None , vae_only = False , load_transformer = True ):
@@ -103,7 +105,6 @@ def __call__(
103105 num_inference_steps : int = 50 ,
104106 guidance_scale_low : float = 3.0 ,
105107 guidance_scale_high : float = 4.0 ,
106- boundary : int = 875 ,
107108 num_videos_per_prompt : Optional [int ] = 1 ,
108109 max_sequence_length : int = 512 ,
109110 latents : jax .Array = None ,
@@ -129,11 +130,13 @@ def __call__(
129130 low_noise_graphdef , low_noise_state , low_noise_rest = nnx .split (self .low_noise_transformer , nnx .Param , ...)
130131 high_noise_graphdef , high_noise_state , high_noise_rest = nnx .split (self .high_noise_transformer , nnx .Param , ...)
131132
133+ boundary_timestep = self .boundary_ratio * self .scheduler .config .num_train_timesteps
134+
132135 p_run_inference = partial (
133136 run_inference_2_2 ,
134137 guidance_scale_low = guidance_scale_low ,
135138 guidance_scale_high = guidance_scale_high ,
136- boundary = boundary ,
139+ boundary = boundary_timestep ,
137140 num_inference_steps = num_inference_steps ,
138141 scheduler = self .scheduler ,
139142 scheduler_state = scheduler_state ,
0 commit comments