
Commit 31dcb6c

Wan caus vid (#196)

Adds the caus_vid model for faster inference.

Co-authored-by: Juan Acevedo <juancevedo@google.com>

entrpn and Juan Acevedo authored · 1 parent: b4b5a45

4 files changed · 87 additions & 1 deletion


src/maxdiffusion/configs/base_wan_14b.yml
Lines changed: 3 additions & 0 deletions

@@ -29,6 +29,9 @@ log_period: 100
 
 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
 
+# Overrides the transformer from pretrained_model_name_or_path
+wan_transformer_pretrained_model_name_or_path: ''
+
 unet_checkpoint: ''
 revision: ''
 # This will convert the weights to this dtype.
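Setting the new key to a supported repo id swaps only the transformer; leaving it empty falls back to pretrained_model_name_or_path (see the pyconfig.py change below). A hypothetical invocation — the flag names come from this commit, while the entry-point script and step count are illustrative assumptions:

python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml \
  wan_transformer_pretrained_model_name_or_path='lightx2v/Wan2.1-T2V-14B-CausVid' \
  num_inference_steps=4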

src/maxdiffusion/models/wan/wan_utils.py
Lines changed: 64 additions & 0 deletions

@@ -1,4 +1,5 @@
 import json
+import torch
 import jax
 import jax.numpy as jnp
 from maxdiffusion import max_logging
@@ -7,6 +8,8 @@
 from flax.traverse_util import unflatten_dict, flatten_dict
 from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
 
+CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
+
 
 def _tuple_str_to_int(in_tuple):
   out_list = []
@@ -25,7 +28,68 @@ def rename_for_nnx(key):
   return new_key
 
 
+def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+  device = jax.devices(device)[0]
+  with jax.default_device(device):
+    if hf_download:
+      ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
+      loaded_state_dict = torch.load(ckpt_shard_path)
+
+    tensors = {}
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+    flattened_dict = flatten_dict(eval_shapes)
+    # turn all block numbers to strings just for matching weights.
+    # Later they will be turned back to ints.
+    random_flax_state_dict = {}
+    for key in flattened_dict:
+      string_tuple = tuple([str(item) for item in key])
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    for pt_key, tensor in loaded_state_dict.items():
+      tensor = torch2jax(tensor)
+      renamed_pt_key = rename_key(pt_key)
+      renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
+      renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
+      renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
+      renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
+      renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
+      renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")
+
+      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
+      renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
+      renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
+      renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
+      renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
+      renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
+      renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
+      renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
+      renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
+      renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
+
+      pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
+      flax_key = rename_for_nnx(flax_key)
+      flax_key = _tuple_str_to_int(flax_key)
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+    validate_flax_state_dict(eval_shapes, flax_state_dict)
+    flax_state_dict = unflatten_dict(flax_state_dict)
+  del tensors
+  jax.clear_caches()
+  return flax_state_dict
+
+
 def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+
+  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+  else:
+    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+
+
+def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
   with jax.default_device(device):
     if hf_download:
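The long chain of replace() calls above maps CausVid's PyTorch parameter names onto the Flax module tree before the tuple-key lookup. A minimal runnable sketch of that chain, using a subset of the pairs from the diff and an illustrative sample key:

# Subset of the renaming pairs from load_causvid_transformer above.
RENAMES = [
    ("blocks_", "blocks."),
    ("self_attn", "attn1"),
    ("cross_attn", "attn2"),
    (".q.", ".query."),
    (".k.", ".key."),
    (".v.", ".value."),
    (".o.", ".proj_attn."),
]

def rename(pt_key: str) -> str:
  # Apply each substitution in order, as the loader does.
  for old, new in RENAMES:
    pt_key = pt_key.replace(old, new)
  return pt_key

print(rename("blocks_0.self_attn.q.weight"))  # -> blocks.0.attn1.query.weight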

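A related subtlety: flatten_dict over the Flax shape tree yields tuple keys whose block indices are ints, while the renamed PyTorch keys split into all-string tuples — hence the stringify-then-_tuple_str_to_int round trip. A toy illustration (the shape value is made up):

from flax.traverse_util import flatten_dict

# Toy stand-in for eval_shapes; only the key structure matters here.
eval_shapes = {"blocks": {0: {"attn1": {"query": {"kernel": (4, 4)}}}}}
flattened = flatten_dict(eval_shapes)
print(list(flattened))   # [('blocks', 0, 'attn1', 'query', 'kernel')]
as_strings = {tuple(str(p) for p in k): v for k, v in flattened.items()}
print(list(as_strings))  # [('blocks', '0', 'attn1', 'query', 'kernel')]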
src/maxdiffusion/pipelines/wan/wan_pipeline.py
Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without first copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
+  params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
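The context line after the call site casts every loaded leaf to config.weights_dtype in a single tree_map pass. A tiny self-contained illustration of that pattern, with toy shapes and a hard-coded target dtype:

import jax
import jax.numpy as jnp

params = {"proj_out": {"kernel": jnp.ones((2, 2), dtype=jnp.float32)}}
# Cast every array leaf in the tree, as the pipeline does with config.weights_dtype.
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
print(params["proj_out"]["kernel"].dtype)  # bfloat16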

src/maxdiffusion/pyconfig.py
Lines changed: 19 additions & 0 deletions

@@ -25,6 +25,7 @@
 import yaml
 from . import max_logging
 from . import max_utils
+from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
 
 
 def string_to_bool(s: str) -> bool:
@@ -102,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
       jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
 
     _HyperParameters.user_init(raw_keys)
+    _HyperParameters.wan_init(raw_keys)
     self.keys = raw_keys
     for k in sorted(raw_keys.keys()):
      max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -110,6 +112,23 @@ def _load_kwargs(self, argv: list[str]):
     args_dict = dict(a.split("=", 1) for a in argv[2:])
     return args_dict
 
+  @staticmethod
+  def wan_init(raw_keys):
+    if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
+      transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
+      if transformer_pretrained_model_name_or_path == "":
+        raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
+      elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
+        # Set correct parameters for CausVid in case of user error.
+        raw_keys["guidance_scale"] = 1.0
+        num_inference_steps = raw_keys["num_inference_steps"]
+        if num_inference_steps > 10:
+          max_logging.log(
+              f"Warning: Try setting num_inference_steps to less than 8 steps when using CausVid, currently you are setting {num_inference_steps} steps."
+          )
+      else:
+        raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
+
   @staticmethod
   def user_init(raw_keys):
     """Transformations between the config data and configs used at runtime"""

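Note that wan_init sees raw_keys after command-line overrides have been merged by _load_kwargs, whose split("=", 1) parsing appears in the context lines above. A standalone illustration (argv values are invented; values arrive as strings and are coerced by later config handling):

argv = [
    "prog",
    "src/maxdiffusion/configs/base_wan_14b.yml",
    "wan_transformer_pretrained_model_name_or_path=lightx2v/Wan2.1-T2V-14B-CausVid",
]
args_dict = dict(a.split("=", 1) for a in argv[2:])
print(args_dict)
# {'wan_transformer_pretrained_model_name_or_path': 'lightx2v/Wan2.1-T2V-14B-CausVid'}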