Skip to content

Commit 8186723

Browse files
vij_wan
1 parent 87817d0 commit 8186723

5 files changed

Lines changed: 96 additions & 0 deletions

File tree

lat.npy

10.5 MB
Binary file not shown.

latents.npy

2.64 MB
Binary file not shown.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence
16+
import jax
17+
import time
18+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
19+
from maxdiffusion import pyconfig
20+
from absl import app
21+
from maxdiffusion.utils import export_to_video
22+
23+
def run(config):
  """Compile and benchmark the WAN text-to-video pipeline.

  Runs the pipeline twice with identical arguments: the first call pays the
  XLA compilation cost, the second — wrapped in a jax.profiler trace written
  to /tmp/trace/ — measures steady-state generation time. Each run exports
  its videos to wan_output_{i}.mp4, so the second run overwrites the first
  run's files.

  Args:
    config: pyconfig configuration carrying prompt, negative_prompt, height,
      width, num_frames, num_inference_steps and guidance_scale.
  """
  pipeline = WanPipeline.from_pretrained(config)

  def _generate():
    # One prompt (and negative prompt) per device so the batch divides
    # evenly across jax.device_count() devices.
    return pipeline(
        prompt=[config.prompt] * jax.device_count(),
        negative_prompt=[config.negative_prompt] * jax.device_count(),
        height=config.height,
        width=config.width,
        num_frames=config.num_frames,
        num_inference_steps=config.num_inference_steps,
        guidance_scale=config.guidance_scale,
    )

  def _export(videos):
    # NOTE(review): assumes each element of `videos` is a frame sequence
    # accepted by export_to_video — confirm against WanPipeline's return type.
    for i, video in enumerate(videos):
      export_to_video(video, f"wan_output_{i}.mp4", fps=16)

  s0 = time.perf_counter()
  videos = _generate()
  print("compile time: ", (time.perf_counter() - s0))
  _export(videos)

  s0 = time.perf_counter()
  with jax.profiler.trace("/tmp/trace/"):
    videos = _generate()
  print("generation time: ", (time.perf_counter() - s0))
  _export(videos)
53+
54+
55+
def main(argv: Sequence[str]) -> None:
  """Script entry point: parse the maxdiffusion config from argv, then run."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)


if __name__ == "__main__":
  app.run(main)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Reference script: text-to-video generation with Wan2.1-T2V-14B via HF diffusers.
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# VAE is loaded in float32 while the rest of the pipeline runs in bfloat16.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
# NOTE(review): the guidance below says 3.0 for 480P, but 5.0 is used even
# though the generation call requests 480x720 (480P) — confirm intent.
flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)


prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"


# .frames is batched; take the first (and only) video's frame list.
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=720,
    num_frames=21,
    guidance_scale=5.0,
    num_inference_steps=10,
).frames[0]


export_to_video(output, "output.mp4", fps=16)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,9 @@ def __call__(
404404
num_frames=num_frames,
405405
num_channels_latents=num_channel_latents
406406
)
407+
408+
# import pdb
409+
# pdb.set_trace()
407410

408411
data_sharding = PositionalSharding(self.devices_array).replicate()
409412
if len(prompt) % jax.device_count() == 0:

0 commit comments

Comments
 (0)