-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathgenerate_wan.py
More file actions
109 lines (95 loc) · 3.37 KB
/
generate_wan.py
File metadata and controls
109 lines (95 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence
import jax
import time
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion import pyconfig, max_logging, max_utils
from absl import app
from maxdiffusion.utils import export_to_video
def run(config, pipeline=None, filename_prefix=""):
  """Generate videos with a WAN pipeline and benchmark the generation.

  The pipeline is invoked three times with identical arguments: a first
  pass that includes JIT compilation, a warm pass to measure steady-state
  latency, and a final pass that is optionally captured by the profiler.
  Only the videos from the first pass are written to disk.

  Args:
    config: pyconfig config object providing seed, prompt, resolution,
      frame count, inference steps, guidance/SLG settings, batch sizing,
      fps, and profiler flags.
    pipeline: optional pre-built WanPipeline; loaded from `config` when None.
    filename_prefix: prefix prepended to each saved .mp4 filename.

  Returns:
    List of file paths of the videos saved from the first pass.
  """
  print("seed: ", config.seed)
  if pipeline is None:
    pipeline = WanPipeline.from_pretrained(config)

  # Skip-layer-guidance settings, passed straight through to the pipeline.
  slg_layers = config.slg_layers
  slg_start = config.slg_start
  slg_end = config.slg_end

  # An explicit non-zero global batch size wins; otherwise derive the batch
  # from the device count and the per-device batch size.
  global_batch_size = config.global_batch_size
  if global_batch_size != 0:
    batch_multiplier = global_batch_size
  else:
    batch_multiplier = jax.device_count() * config.per_device_batch_size

  prompt = [config.prompt] * batch_multiplier
  negative_prompt = [config.negative_prompt] * batch_multiplier

  def _generate():
    # Single pipeline invocation; factored out because the benchmark below
    # runs the exact same call three times.
    return pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
        slg_layers=slg_layers,
        slg_start=slg_start,
        slg_end=slg_end,
    )

  max_logging.log(
      f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
  )

  # First pass: timing includes XLA compilation.
  s0 = time.perf_counter()
  videos = _generate()
  print("compile time: ", (time.perf_counter() - s0))

  saved_video_path = []
  for i, video in enumerate(videos):
    video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
    export_to_video(video, video_path, fps=config.fps)
    saved_video_path.append(video_path)

  # Second pass: compilation is cached, so this measures the warm latency.
  # (Previously mislabeled as "compile time".)
  s0 = time.perf_counter()
  videos = _generate()
  print("warm run time: ", (time.perf_counter() - s0))

  # Third pass: the reported generation time, optionally under the profiler.
  s0 = time.perf_counter()
  if config.enable_profiler:
    max_utils.activate_profiler(config)
  videos = _generate()
  # Deactivation is unconditional, mirroring the original control flow.
  max_utils.deactivate_profiler(config)
  print("generation time: ", (time.perf_counter() - s0))
  return saved_video_path
def main(argv: Sequence[str]) -> None:
  """CLI entry point: initialize pyconfig from argv and run generation."""
  pyconfig.initialize(argv)
  run(pyconfig.config)
# Script entry point; absl's app.run handles flag parsing and passes argv
# through to main().
if __name__ == "__main__":
  app.run(main)