1- # Copyright 2025 Google LLC
2- #
3- # Licensed under the Apache License, Version 2.0 (the "License");
4- # you may not use this file except in compliance with the License.
5- # You may obtain a copy of the License at
6- #
7- # http://www.apache.org/licenses/LICENSE-2.0
8- #
9- # Unless required by applicable law or agreed to in writing, software
10- # distributed under the License is distributed on an "AS IS" BASIS,
11- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12- # See the License for the specific language governing permissions and
13- # limitations under the License.
1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
1416
1517from typing import Sequence
1618import jax
1719import time
1820import os
1921import subprocess
20- from maxdiffusion .checkpointing .wan_checkpointer_2_1 import WanCheckpointer2_1
21- from maxdiffusion .checkpointing .wan_checkpointer_2_2 import WanCheckpointer2_2
22- from maxdiffusion .checkpointing .wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1
23- from maxdiffusion .checkpointing .wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
2422from maxdiffusion import pyconfig , max_logging , max_utils
2523from absl import app
2624from maxdiffusion .utils import export_to_video
2927import flax
3028from maxdiffusion .common_types import WAN2_1 , WAN2_2
3129from maxdiffusion .loaders .wan_lora_nnx_loader import Wan2_1NNXLoraLoader , Wan2_2NNXLoraLoader
30+ from maxdiffusion .inference .loader import InferenceLoader
31+ from maxdiffusion .inference .runner import DiffusionRunner
3232
3333
3434def upload_video_to_gcs (output_dir : str , video_path : str ):
@@ -84,84 +84,6 @@ def get_git_commit_hash():
8484jax .config .update ("jax_use_shardy_partitioner" , True )
8585
8686
def call_pipeline(config, pipeline, prompt, negative_prompt):
  """Invoke the WAN video pipeline with model-appropriate arguments.

  Builds the keyword arguments shared by every model/type combination once,
  then layers on the pieces that differ: an image for I2V conditioning, and
  the guidance-scale form expected by the model version (WAN 2.1 takes a
  single ``guidance_scale``; WAN 2.2 splits it into low/high-noise scales).

  Args:
    config: Run configuration providing ``model_name``, ``model_type`` and the
      generation hyperparameters read below.
    pipeline: Loaded diffusion pipeline callable.
    prompt: Batched list of prompts.
    negative_prompt: Batched list of negative prompts.

  Returns:
    The videos produced by the pipeline call.

  Raises:
    ValueError: If ``config.model_type`` is not "I2V"/"T2V", or
      ``config.model_name`` is not WAN2_1/WAN2_2.  (The original silently
      returned None for an unknown model_type; that now fails loudly.)
  """
  model_key = config.model_name
  model_type = config.model_type

  # Arguments common to every supported model/type combination.
  kwargs = dict(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
  )

  if model_type == "I2V":
    # Image-to-video additionally conditions on a reference image.
    kwargs["image"] = load_image(config.image_url)
  elif model_type != "T2V":
    raise ValueError(f"Unsupported model_type in config: {model_type}")

  if model_key == WAN2_1:
    kwargs["guidance_scale"] = config.guidance_scale
  elif model_key == WAN2_2:
    kwargs["guidance_scale_low"] = config.guidance_scale_low
    kwargs["guidance_scale_high"] = config.guidance_scale_high
  else:
    # Note: fixes the original "T2Vin" message typo.
    raise ValueError(f"Unsupported model_name for {model_type} in config: {model_key}")

  return pipeline(**kwargs)
142-
def inference_generate_video(config, pipeline, filename_prefix=""):
  """Generate a batch of videos with ``pipeline`` and export them as mp4.

  Prompts are replicated to the global batch size, the pipeline is invoked
  once, and each resulting video is written locally; when ``config.output_dir``
  is a GCS path the file is also uploaded and the local copy removed.
  """
  start = time.perf_counter()
  batch_size = config.global_batch_size_to_train_on
  prompts = [config.prompt] * batch_size
  negative_prompts = [config.negative_prompt] * batch_size

  max_logging.log(
      f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
  )

  videos = call_pipeline(config, pipeline, prompts, negative_prompts)

  max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - start)}")
  for index, video in enumerate(videos):
    video_path = f"{filename_prefix}wan_output_{config.seed}_{index}.mp4"
    export_to_video(video, video_path, fps=config.fps)
    if config.output_dir.startswith("gs://"):
      upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
      # Delete local files to avoid storing too many videos
      delete_file(f"./{video_path}")
  return
163-
164-
16587def run (config , pipeline = None , filename_prefix = "" , commit_hash = None ):
16688 model_key = config .model_name
16789 writer = max_utils .initialize_summary_writer (config )
@@ -174,23 +96,22 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
17496 else :
17597 max_logging .log ("Could not retrieve Git commit hash." )
17698
99+ loaded_model = None
177100 if pipeline is None :
178- model_type = config .model_type
179- if model_key == WAN2_1 :
180- if model_type == "I2V" :
181- checkpoint_loader = WanCheckpointerI2V_2_1 (config = config )
182- else :
183- checkpoint_loader = WanCheckpointer2_1 (config = config )
184- elif model_key == WAN2_2 :
185- if model_type == "I2V" :
186- checkpoint_loader = WanCheckpointerI2V_2_2 (config = config )
187- else :
188- checkpoint_loader = WanCheckpointer2_2 (config = config )
189- else :
190- raise ValueError (f"Unsupported model_name for checkpointer: { model_key } " )
191- pipeline , _ , _ = checkpoint_loader .load_checkpoint ()
101+ max_logging .log ("Initializing InferenceLoader..." )
102+ loaded_model = InferenceLoader .load (config )
103+ pipeline = loaded_model ["pipeline" ]
104+ else :
105+ # If pipeline passed explicitly (e.g. from test), wrap it
106+ # But InferenceLoader logic assumes it creates it.
107+ # We construct a dummy loaded_model dict
108+ loaded_model = {
109+ "pipeline" : pipeline ,
110+ "mesh" : getattr (config , "mesh" , None ) # Fallback
111+ }
192112
193113 # If LoRA is specified, inject layers and load weights.
114+ # TODO: Move this into InferenceLoader._load_wan eventually
194115 if (
195116 config .enable_lora
196117 and hasattr (config , "lora_config" )
@@ -225,17 +146,22 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
225146 scan_layers = config .scan_layers ,
226147 dtype = config .weights_dtype ,
227148 )
149+ # Update loaded model with modified pipeline
150+ loaded_model ["pipeline" ] = pipeline
228151
229152 s0 = time .perf_counter ()
230153
231- # Using global_batch_size_to_train_on so not to create more config variables
232- prompt = [config .prompt ] * config .global_batch_size_to_train_on
233- negative_prompt = [config .negative_prompt ] * config .global_batch_size_to_train_on
154+ max_logging .log ("Initializing DiffusionRunner..." )
155+ runner = DiffusionRunner (loaded_model , config )
234156
235157 max_logging .log (
236158 f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } "
237159 )
238- videos = call_pipeline (config , pipeline , prompt , negative_prompt )
160+
161+ # Using global_batch_size_to_train_on logic is handled by Runner/Pipeline mostly now
162+ # But we can override args
163+
164+ videos = runner .run ()
239165
240166 max_logging .log ("===================== Model details =======================" )
241167 max_logging .log (f"model name: { config .model_name } " )
@@ -257,9 +183,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
257183 saved_video_path .append (video_path )
258184 if config .output_dir .startswith ("gs://" ):
259185 upload_video_to_gcs (os .path .join (config .output_dir , config .run_name ), video_path )
186+ delete_file (f"./{ video_path } " )
260187
261188 s0 = time .perf_counter ()
262- videos = call_pipeline ( config , pipeline , prompt , negative_prompt )
189+ videos = runner . run ( )
263190 generation_time = time .perf_counter () - s0
264191 max_logging .log (f"generation_time: { generation_time } " )
265192 if writer and jax .process_index () == 0 :
@@ -272,10 +199,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
272199 max_logging .log (f"generation time per video: { generation_time_per_video } " )
273200 else :
274201 max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
202+
275203 s0 = time .perf_counter ()
276204 if config .enable_profiler :
277205 max_utils .activate_profiler (config )
278- videos = call_pipeline ( config , pipeline , prompt , negative_prompt )
206+ videos = runner . run ( )
279207 max_utils .deactivate_profiler (config )
280208 generation_time_with_profiler = time .perf_counter () - s0
281209 max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
@@ -296,4 +224,4 @@ def main(argv: Sequence[str]) -> None:
296224
297225
if __name__ == "__main__":
  # absl's app.run parses flags before dispatching to main.
  app.run(main)
0 commit comments