2020from absl import app
2121from maxdiffusion .utils import export_to_video
2222
23+
2324def run (config ):
2425 print ("seed: " , config .seed )
2526 pipeline = WanPipeline .from_pretrained (config )
2627 s0 = time .perf_counter ()
27-
28+
2829 # Skip layer guidance
2930 slg_layers = config .slg_layers
3031 slg_start = config .slg_start
3132 slg_end = config .slg_end
3233
3334 prompt = [config .prompt ] * jax .device_count ()
34- negative_prompt = [config .negative_prompt ] * jax .device_count ()
35-
36- videos = pipeline (
37- prompt = prompt ,
38- negative_prompt = negative_prompt ,
39- height = config .height ,
40- width = config .width ,
41- num_frames = config .num_frames ,
42- num_inference_steps = config .num_inference_steps ,
43- guidance_scale = config .guidance_scale ,
44- slg_layers = slg_layers ,
45- slg_start = slg_start ,
46- slg_end = slg_end
47- )
35+ negative_prompt = [config .negative_prompt ] * jax .device_count ()
4836
49- print ("compile time: " , (time .perf_counter () - s0 ))
50- for i in range (len (videos )):
51- export_to_video (videos [i ], f"wan_output_{ config .seed } _{ i } .mp4" , fps = config .fps )
52- s0 = time .perf_counter ()
53- with jax .profiler .trace ("/tmp/trace/" ):
54- videos = pipeline (
37+ videos = pipeline (
5538 prompt = prompt ,
5639 negative_prompt = negative_prompt ,
5740 height = config .height ,
@@ -61,7 +44,25 @@ def run(config):
6144 guidance_scale = config .guidance_scale ,
6245 slg_layers = slg_layers ,
6346 slg_start = slg_start ,
64- slg_end = slg_end
47+ slg_end = slg_end ,
48+ )
49+
50+ print ("compile time: " , (time .perf_counter () - s0 ))
51+ for i in range (len (videos )):
52+ export_to_video (videos [i ], f"wan_output_{ config .seed } _{ i } .mp4" , fps = config .fps )
53+ s0 = time .perf_counter ()
54+ with jax .profiler .trace ("/tmp/trace/" ):
55+ videos = pipeline (
56+ prompt = prompt ,
57+ negative_prompt = negative_prompt ,
58+ height = config .height ,
59+ width = config .width ,
60+ num_frames = config .num_frames ,
61+ num_inference_steps = config .num_inference_steps ,
62+ guidance_scale = config .guidance_scale ,
63+ slg_layers = slg_layers ,
64+ slg_start = slg_start ,
65+ slg_end = slg_end ,
6566 )
6667 print ("generation time: " , (time .perf_counter () - s0 ))
6768 for i in range (len (videos )):
@@ -74,4 +75,4 @@ def main(argv: Sequence[str]) -> None:
7475
7576
7677if __name__ == "__main__" :
77- app .run (main )
78+ app .run (main )
0 commit comments