Skip to content

Commit 5cc2e49

Browse files
implements a working wan 2.1 pipeline.
1 parent 2388908 commit 5cc2e49

2 files changed

Lines changed: 16 additions & 20 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from typing import Sequence
16+
import jax
1617
import time
1718
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
1819
from maxdiffusion import pyconfig
@@ -21,7 +22,6 @@
2122

2223
def run(config):
2324
pipeline = WanPipeline.from_pretrained(config)
24-
2525
s0 = time.perf_counter()
2626
video = pipeline(
2727
prompt=config.prompt,
@@ -32,17 +32,20 @@ def run(config):
3232
num_inference_steps=config.num_inference_steps,
3333
guidance_scale=config.guidance_scale,
3434
)
35+
3536
print("compile time: ", (time.perf_counter() - s0))
37+
export_to_video(video[0], "jax_output.mp4", fps=16)
3638
s0 = time.perf_counter()
37-
video = pipeline(
38-
prompt=config.prompt,
39-
negative_prompt=config.negative_prompt,
40-
height=config.height,
41-
width=config.width,
42-
num_frames=config.num_frames,
43-
num_inference_steps=config.num_inference_steps,
44-
guidance_scale=config.guidance_scale,
45-
)
39+
with jax.profiler.trace("/tmp/trace/"):
40+
video = pipeline(
41+
prompt=config.prompt,
42+
negative_prompt=config.negative_prompt,
43+
height=config.height,
44+
width=config.width,
45+
num_frames=config.num_frames,
46+
num_inference_steps=config.num_inference_steps,
47+
guidance_scale=config.guidance_scale,
48+
)
4649
print("generation time: ", (time.perf_counter() - s0))
4750
export_to_video(video[0], "jax_output.mp4", fps=16)
4851

@@ -51,5 +54,6 @@ def main(argv: Sequence[str]) -> None:
5154
pyconfig.initialize(argv)
5255
run(pyconfig.config)
5356

57+
5458
if __name__ == "__main__":
5559
app.run(main)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,6 @@ def __init__(
109109
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
110110
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
111111

112-
self.jitted_decode = jax.jit(
113-
partial(
114-
self.vae.decode,
115-
feat_cache=self.vae_cache,
116-
return_dict=False
117-
)
118-
)
119-
120112
self.p_run_inference = None
121113

122114
@classmethod
@@ -402,8 +394,8 @@ def __call__(
402394
latents = latents / latents_std + latents_mean
403395
latents = latents.astype(self.config.weights_dtype)
404396

405-
with self.mesh:
406-
video = self.jitted_decode(latents)[0]
397+
video = self.vae.decode(latents, self.vae_cache)[0]
398+
407399
video = jnp.transpose(video, (0, 4, 1, 2, 3))
408400
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
409401
video = self.video_processor.postprocess_video(video, output_type="np")

0 commit comments

Comments
 (0)