Skip to content

Commit 63b5ed7

Browse files
authored
Enable JIT Compilation of WAN VAE Encoder/Decoder Forward Passes (#320)
1 parent ad56886 commit 63b5ed7

7 files changed

Lines changed: 185 additions & 106 deletions

File tree

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ gcs_metrics: False
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30-
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-480P-Diffusers'
30+
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-720P-Diffusers'
3131
model_name: wan2.1
3232
model_type: 'I2V'
3333

@@ -280,16 +280,16 @@ prompt: "An astronaut hatching from an egg, on the surface of the moon, the dark
280280
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281281
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
282282
do_classifier_free_guidance: True
283-
height: 480
284-
width: 832
283+
height: 720
284+
width: 1280
285285
num_frames: 81
286286
guidance_scale: 5.0
287-
flow_shift: 3.0
287+
flow_shift: 5.0
288288

289289
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
290290
guidance_rescale: 0.0
291-
num_inference_steps: 30
292-
fps: 24
291+
num_inference_steps: 50
292+
fps: 16
293293
save_final_checkpoint: False
294294

295295
# SDXL Lightning parameters

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,10 @@ prompt: "An astronaut hatching from an egg, on the surface of the moon, the dark
281281
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
282282
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
283283
do_classifier_free_guidance: True
284-
height: 480
285-
width: 832
284+
height: 720
285+
width: 1280
286286
num_frames: 81
287-
flow_shift: 3.0
287+
flow_shift: 5.0
288288

289289
# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
290290
# guidance scale factor for low noise transformer
@@ -300,8 +300,8 @@ boundary_ratio: 0.875
300300

301301
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
302302
guidance_rescale: 0.0
303-
num_inference_steps: 30
304-
fps: 24
303+
num_inference_steps: 50
304+
fps: 16
305305
save_final_checkpoint: False
306306

307307
# SDXL Lightning parameters

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#hardware
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
4+
attention: 'flash'
5+
attention_sharding_uniform: True
46

57
jax_cache_dir: ''
68
weights_dtype: 'bfloat16'

src/maxdiffusion/models/vae_flax.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import flax
2222
import flax.linen as nn
2323
import jax
24+
from jax import tree_util
2425
import jax.numpy as jnp
2526
from flax.core.frozen_dict import FrozenDict
2627

@@ -930,3 +931,30 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r
930931
return (sample,)
931932

932933
return FlaxDecoderOutput(sample=sample)
934+
935+
936+
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
  """FlaxDiagonalGaussianDistribution subclass used as a distinct pytree type.

  Adds no attributes or methods of its own. This module registers the class
  with `jax.tree_util.register_pytree_node`, so instances can be flattened to
  and rebuilt from their `mean`, `logvar`, `std`, `var` and `deterministic`
  fields.
  """
938+
939+
940+
def _wan_diag_gauss_dist_flatten(dist):
  """Flatten a distribution into pytree children and static auxiliary data.

  Args:
    dist: Object exposing `mean`, `logvar`, `std`, `var` and `deterministic`.

  Returns:
    A pair `(children, aux_data)` where `children` is the tuple
    `(mean, logvar, std, var)` and `aux_data` is the one-element tuple
    `(deterministic,)`.
  """
  children = (dist.mean, dist.logvar, dist.std, dist.var)
  # `deterministic` travels as auxiliary data rather than as a leaf.
  aux_data = (dist.deterministic,)
  return children, aux_data
942+
943+
944+
def _wan_diag_gauss_dist_unflatten(aux, children):
  """Rebuild a WanDiagonalGaussianDistribution from its flattened form.

  Inverse of `_wan_diag_gauss_dist_flatten`.

  Args:
    aux: Auxiliary data; its first element is the `deterministic` flag.
    children: Four values in the order `(mean, logvar, std, var)`.

  Returns:
    A WanDiagonalGaussianDistribution whose attributes are exactly the
    supplied values.
  """
  mean, logvar, std, var = children
  deterministic = aux[0]
  # Allocate without running __init__: every attribute is restored verbatim
  # from the flattened data instead of being recomputed.
  obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution)
  obj.mean = mean
  obj.logvar = logvar
  obj.std = std
  obj.var = var
  obj.deterministic = deterministic
  return obj
954+
955+
956+
# Register the distribution as a JAX pytree node so JAX tree utilities can
# flatten it into (mean, logvar, std, var) children plus the `deterministic`
# flag as auxiliary data, and rebuild it afterwards.
# NOTE(review): presumably needed so instances can cross `jax.jit`
# boundaries in the WAN VAE encode path; confirm against the callers.
tree_util.register_pytree_node(
    WanDiagonalGaussianDistribution,
    _wan_diag_gauss_dist_flatten,
    _wan_diag_gauss_dist_unflatten,
)

0 commit comments

Comments
 (0)