1313# limitations under the License.
1414
1515from typing import Sequence
16+ import jax
1617import time
1718from maxdiffusion .pipelines .wan .wan_pipeline import WanPipeline
1819from maxdiffusion import pyconfig
2122
2223def run (config ):
2324 pipeline = WanPipeline .from_pretrained (config )
24-
2525 s0 = time .perf_counter ()
2626 video = pipeline (
2727 prompt = config .prompt ,
@@ -32,17 +32,20 @@ def run(config):
3232 num_inference_steps = config .num_inference_steps ,
3333 guidance_scale = config .guidance_scale ,
3434 )
35+
3536 print ("compile time: " , (time .perf_counter () - s0 ))
37+ export_to_video (video [0 ], "jax_output.mp4" , fps = 16 )
3638 s0 = time .perf_counter ()
37- video = pipeline (
38- prompt = config .prompt ,
39- negative_prompt = config .negative_prompt ,
40- height = config .height ,
41- width = config .width ,
42- num_frames = config .num_frames ,
43- num_inference_steps = config .num_inference_steps ,
44- guidance_scale = config .guidance_scale ,
45- )
39+ with jax .profiler .trace ("/tmp/trace/" ):
40+ video = pipeline (
41+ prompt = config .prompt ,
42+ negative_prompt = config .negative_prompt ,
43+ height = config .height ,
44+ width = config .width ,
45+ num_frames = config .num_frames ,
46+ num_inference_steps = config .num_inference_steps ,
47+ guidance_scale = config .guidance_scale ,
48+ )
4649 print ("generation time: " , (time .perf_counter () - s0 ))
4750 export_to_video (video [0 ], "jax_output.mp4" , fps = 16 )
4851
@@ -51,5 +54,6 @@ def main(argv: Sequence[str]) -> None:
5154 pyconfig .initialize (argv )
5255 run (pyconfig .config )
5356
57+
5458if __name__ == "__main__" :
5559 app .run (main )
0 commit comments