-
Notifications
You must be signed in to change notification settings - Fork 70
Expand file tree
/
Copy pathwan_utils.py
More file actions
203 lines (179 loc) · 9.87 KB
/
wan_utils.py
File metadata and controls
203 lines (179 loc) · 9.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import json
import torch
import jax
import jax.numpy as jnp
from maxdiffusion import max_logging
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from flax.traverse_util import unflatten_dict, flatten_dict
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
def _tuple_str_to_int(in_tuple):
out_list = []
for item in in_tuple:
try:
out_list.append(int(item))
except ValueError:
out_list.append(item)
return tuple(out_list)
def rename_for_nnx(key):
  """Rename the leaf of norm_q/norm_k parameter keys to nnx's "scale".

  `key` is a tuple of path components; if it contains a "norm_k" or "norm_q"
  element, the final component is replaced with "scale". Other keys are
  returned unchanged.
  """
  if any(marker in key for marker in ("norm_k", "norm_q")):
    return key[:-1] + ("scale",)
  return key
def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Download and port the CausVid Wan transformer PyTorch checkpoint to a Flax state dict.

  Args:
    pretrained_model_name_or_path: HF Hub repo id containing `causal_model.pt`.
    eval_shapes: flat/nested dict of expected Flax parameter shapes, used both
      for key matching and for final validation.
    device: JAX platform name (e.g. "cpu", "tpu") used as the default device
      while converting.
    hf_download: must be True; only the HF Hub download path is implemented.

  Returns:
    Nested (unflattened) Flax state dict with all weights placed on host CPU.

  Raises:
    ValueError: if `hf_download` is False.
  """
  if not hf_download:
    # Previously this fell through to a NameError on an undefined state dict;
    # fail loudly with a clear message instead.
    raise ValueError("load_causvid_transformer only supports hf_download=True.")
  device = jax.devices(device)[0]
  with jax.default_device(device):
    ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
    # map_location="cpu" avoids torch trying to restore tensors onto whatever
    # accelerator they were saved from; they become JAX CPU arrays below anyway.
    loaded_state_dict = torch.load(ckpt_shard_path, map_location="cpu")
    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # Turn all block numbers to strings just for matching weights.
    # Later they will be turned back to ints.
    random_flax_state_dict = {tuple(str(item) for item in key): value for key, value in flattened_dict.items()}
    for pt_key, tensor in loaded_state_dict.items():
      tensor = torch2jax(tensor)
      renamed_pt_key = rename_key(pt_key)
      # Map CausVid/PyTorch module names onto the Flax transformer layout.
      renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
      renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
      renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
      renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
      renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
      renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
      renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")
      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
      renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
      renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
      renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
      renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
      renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
      renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
      renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
      renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
      renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
      renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
      pt_tuple_key = tuple(renamed_pt_key.split("."))
      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Keep converted weights in host memory; the caller decides final sharding.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    jax.clear_caches()
  return flax_state_dict
def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Dispatch to the transformer loader that matches the requested checkpoint.

  The CausVid distillation checkpoint ships as a single PyTorch file and needs
  its own porting path; every other repo id goes through the base safetensors
  loader.
  """
  is_causvid = pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
  loader = load_causvid_transformer if is_causvid else load_base_wan_transformer
  return loader(pretrained_model_name_or_path, eval_shapes, device, hf_download)
def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Download and port the sharded Wan 2.1 transformer safetensors to a Flax state dict.

  Args:
    pretrained_model_name_or_path: HF Hub repo id with a `transformer/` subfolder
      containing `diffusion_pytorch_model.safetensors.index.json` and its shards.
    eval_shapes: dict of expected Flax parameter shapes used for key matching
      and final validation.
    device: JAX platform name used as the default device while converting.
    hf_download: must be True; only the HF Hub download path is implemented.

  Returns:
    Nested (unflattened) Flax state dict with all weights placed on host CPU.

  Raises:
    ValueError: if `hf_download` is False.
  """
  if not hf_download:
    # Previously this fell through to a NameError on undefined `model_files`;
    # fail loudly with a clear message instead.
    raise ValueError("load_base_wan_transformer only supports hf_download=True.")
  device = jax.devices(device)[0]
  with jax.default_device(device):
    # Download the index file for sharded models.
    index_file_path = hf_hub_download(
        pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json"
    )
    # Open the index file and collect the unique shard filenames to download.
    with open(index_file_path, "r") as f:
      index_dict = json.load(f)
    model_files = sorted(set(index_dict["weight_map"].values()))
    max_logging.log(f"Load and port Wan 2.1 transformer on {device}")
    tensors = {}
    for model_file in model_files:
      ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file)
      if ckpt_shard_path is not None:
        with safe_open(ckpt_shard_path, framework="pt") as f:
          for k in f.keys():
            tensors[k] = torch2jax(f.get_tensor(k))
    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # Turn all block numbers to strings just for matching weights.
    # Later they will be turned back to ints.
    random_flax_state_dict = {tuple(str(item) for item in key): value for key, value in flattened_dict.items()}
    del flattened_dict
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
      # Map diffusers/PyTorch module names onto the Flax transformer layout.
      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
      renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
      renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
      renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
      pt_tuple_key = tuple(renamed_pt_key.split("."))
      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Keep converted weights in host memory; the caller decides final sharding.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    del tensors
    jax.clear_caches()
  return flax_state_dict
def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Download and port the Wan 2.1 VAE safetensors checkpoint to a Flax state dict.

  Args:
    pretrained_model_name_or_path: HF Hub repo id with a `vae/` subfolder
      containing `diffusion_pytorch_model.safetensors`.
    eval_shapes: dict of expected Flax parameter shapes used for key matching
      and final validation.
    device: JAX platform name used as the default device while converting.
    hf_download: when True, fetch the checkpoint from the HF Hub.
      NOTE(review): if False, `ckpt_path` is never assigned and the check below
      raises NameError rather than FileNotFoundError — confirm intended usage.

  Returns:
    Nested (unflattened) Flax state dict with all weights placed on host CPU.

  Raises:
    FileNotFoundError: if the checkpoint path resolves to None.
  """
  device = jax.devices(device)[0]
  with jax.default_device(device):
    if hf_download:
      ckpt_path = hf_hub_download(
          pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors"
      )
    max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
    if ckpt_path is not None:
      # Pull every tensor out of the safetensors file as a JAX array.
      tensors = {}
      with safe_open(ckpt_path, framework="pt") as f:
        for k in f.keys():
          tensors[k] = torch2jax(f.get_tensor(k))
      flax_state_dict = {}
      cpu = jax.local_devices(backend="cpu")[0]
      for pt_key, tensor in tensors.items():
        renamed_pt_key = rename_key(pt_key)
        # Order matters: block/list renames run before the conv renames below,
        # and substring replacements may match inside longer names, so do not
        # reorder this chain.
        renamed_pt_key = renamed_pt_key.replace("up_blocks_", "up_blocks.")
        renamed_pt_key = renamed_pt_key.replace("mid_block_", "mid_block.")
        renamed_pt_key = renamed_pt_key.replace("down_blocks_", "down_blocks.")
        # Flax wraps each conv in an extra `.conv` module level.
        renamed_pt_key = renamed_pt_key.replace("conv_in.bias", "conv_in.conv.bias")
        renamed_pt_key = renamed_pt_key.replace("conv_in.weight", "conv_in.conv.weight")
        renamed_pt_key = renamed_pt_key.replace("conv_out.bias", "conv_out.conv.bias")
        renamed_pt_key = renamed_pt_key.replace("conv_out.weight", "conv_out.conv.weight")
        renamed_pt_key = renamed_pt_key.replace("attentions_", "attentions.")
        renamed_pt_key = renamed_pt_key.replace("resnets_", "resnets.")
        renamed_pt_key = renamed_pt_key.replace("upsamplers_", "upsamplers.")
        renamed_pt_key = renamed_pt_key.replace("resample_", "resample.")
        renamed_pt_key = renamed_pt_key.replace("conv1.bias", "conv1.conv.bias")
        renamed_pt_key = renamed_pt_key.replace("conv1.weight", "conv1.conv.weight")
        renamed_pt_key = renamed_pt_key.replace("conv2.bias", "conv2.conv.bias")
        renamed_pt_key = renamed_pt_key.replace("conv2.weight", "conv2.conv.weight")
        renamed_pt_key = renamed_pt_key.replace("time_conv.bias", "time_conv.conv.bias")
        renamed_pt_key = renamed_pt_key.replace("time_conv.weight", "time_conv.conv.weight")
        renamed_pt_key = renamed_pt_key.replace("quant_conv", "quant_conv.conv")
        renamed_pt_key = renamed_pt_key.replace("conv_shortcut", "conv_shortcut.conv")
        # Decoder and encoder wrap their resample submodule differently.
        if "decoder" in renamed_pt_key:
          renamed_pt_key = renamed_pt_key.replace("resample.1.bias", "resample.layers.1.bias")
          renamed_pt_key = renamed_pt_key.replace("resample.1.weight", "resample.layers.1.weight")
        if "encoder" in renamed_pt_key:
          renamed_pt_key = renamed_pt_key.replace("resample.1", "resample.conv")
        pt_tuple_key = tuple(renamed_pt_key.split("."))
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes)
        flax_key = _tuple_str_to_int(flax_key)
        # Keep converted weights in host memory; the caller decides final sharding.
        flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
      validate_flax_state_dict(eval_shapes, flax_state_dict)
      flax_state_dict = unflatten_dict(flax_state_dict)
      del tensors
      jax.clear_caches()
    else:
      raise FileNotFoundError(f"Path {ckpt_path} was not found")
  return flax_state_dict