|
19 | 19 | from maxdiffusion import pyconfig, max_logging, max_utils |
20 | 20 | from absl import app |
21 | 21 | from maxdiffusion.utils import export_to_video |
| 22 | +import os |
22 | 23 |
|
23 | 24 | jax.config.update("jax_use_shardy_partitioner", True) |
24 | 25 |
|
| 26 | +def inference_generate_video(config, pipeline, filename_prefix=""): |
| 27 | + s0 = time.perf_counter() |
| 28 | + prompt = [config.prompt] * config.global_batch_size_to_train_on |
| 29 | + negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on |
| 30 | + |
| 31 | + max_logging.log( |
| 32 | + f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}" |
| 33 | + ) |
| 34 | + |
| 35 | + videos = pipeline( |
| 36 | + prompt=prompt, |
| 37 | + negative_prompt=negative_prompt, |
| 38 | + height=config.height, |
| 39 | + width=config.width, |
| 40 | + num_frames=config.num_frames, |
| 41 | + num_inference_steps=config.num_inference_steps, |
| 42 | + guidance_scale=config.guidance_scale, |
| 43 | + ) |
| 44 | + |
| 45 | + print(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") |
| 46 | + for i in range(len(videos)): |
| 47 | + video_path = os.path.join(config.output_dir, "videos", f"{filename_prefix}wan_output_{config.seed}_{i}.mp4") |
| 48 | + export_to_video(videos[i], video_path, fps=config.fps) |
| 49 | + return |
25 | 50 |
|
26 | 51 | def run(config, pipeline=None, filename_prefix=""): |
27 | 52 | print("seed: ", config.seed) |
@@ -54,7 +79,7 @@ def run(config, pipeline=None, filename_prefix=""): |
54 | 79 | print("compile time: ", (time.perf_counter() - s0)) |
55 | 80 | saved_video_path = [] |
56 | 81 | for i in range(len(videos)): |
57 | | - video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" |
| 82 | + video_path = os.path.join(config.output_dir, "videos", f"{filename_prefix}wan_output_{config.seed}_{i}.mp4") |
58 | 83 | export_to_video(videos[i], video_path, fps=config.fps) |
59 | 84 | saved_video_path.append(video_path) |
60 | 85 |
|
|
0 commit comments