Skip to content

Commit 0ef8c71

Browse files
load wan 2.1 transformer weights.
1 parent 440f39c commit 0ef8c71

3 files changed

Lines changed: 69 additions & 12 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ gcs_metrics: False
2727
save_config_to_gcs: False
2828
log_period: 100
2929

30-
pretrained_model_name_or_path: ''
30+
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
3131

3232
unet_checkpoint: ''
3333
revision: ''

src/maxdiffusion/models/attention_flax.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,6 @@ def __init__(
618618
in_features=self.inner_dim,
619619
out_features=self.inner_dim,
620620
kernel_init=qkv_init_kernel,
621-
use_bias=qkv_bias,
622621
dtype=dtype,
623622
param_dtype=weights_dtype,
624623
precision=precision,
@@ -629,7 +628,6 @@ def __init__(
629628
in_features=self.inner_dim,
630629
out_features=self.inner_dim,
631630
kernel_init=qkv_init_kernel,
632-
use_bias=qkv_bias,
633631
dtype=dtype,
634632
param_dtype=weights_dtype,
635633
precision=precision,
@@ -640,7 +638,6 @@ def __init__(
640638
in_features=self.inner_dim,
641639
out_features=self.inner_dim,
642640
kernel_init=qkv_init_kernel,
643-
use_bias=qkv_bias,
644641
dtype=dtype,
645642
param_dtype=weights_dtype,
646643
precision=precision,
@@ -651,16 +648,15 @@ def __init__(
651648
in_features=self.inner_dim,
652649
out_features=self.inner_dim,
653650
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
654-
use_bias=qkv_bias,
655651
dtype=dtype,
656652
param_dtype=weights_dtype,
657653
precision=precision,
658654
)
659655

660-
self.query_norm = None
661-
self.key_norm = None
656+
self.norm_q = None
657+
self.norm_k = None
662658
if qk_norm is not None:
663-
self.query_norm = nnx.RMSNorm(
659+
self.norm_q = nnx.RMSNorm(
664660
num_features=self.inner_dim,
665661
rngs=rngs,
666662
epsilon=eps,
@@ -669,7 +665,7 @@ def __init__(
669665
param_dtype=weights_dtype
670666
)
671667

672-
self.key_norm = nnx.RMSNorm(
668+
self.norm_k = nnx.RMSNorm(
673669
num_features=self.inner_dim,
674670
rngs=rngs,
675671
dtype=dtype,
@@ -713,8 +709,8 @@ def __call__(
713709
value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
714710

715711
if self.qk_norm:
716-
query_proj = self.query_norm(query_proj)
717-
key_proj = self.key_norm(key_proj)
712+
query_proj = self.norm_q(query_proj)
713+
key_proj = self.norm_k(key_proj)
718714

719715
if rotary_emb is not None:
720716
query_proj = _unflatten_heads(query_proj, self.heads)

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import json
12
import jax
23
import jax.numpy as jnp
34
from maxdiffusion import max_logging
45
from huggingface_hub import hf_hub_download
56
from safetensors import safe_open
6-
from flax.traverse_util import unflatten_dict
7+
from flax.traverse_util import unflatten_dict, flatten_dict
78
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
89

910

@@ -16,6 +17,66 @@ def _tuple_str_to_int(in_tuple):
1617
out_list.append(item)
1718
return tuple(out_list)
1819

20+
def rename_for_nnx(key):
  """Map a ported checkpoint key to its nnx parameter name.

  nnx.RMSNorm stores its learned weight under "scale", so any key tuple
  that contains a norm_q / norm_k module gets its final element replaced
  with "scale". Every other key passes through unchanged.
  """
  if "norm_q" in key or "norm_k" in key:
    return key[:-1] + ("scale",)
  return key
25+
26+
def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
27+
device = jax.devices(device)[0]
28+
with jax.default_device(device):
29+
if hf_download:
30+
# download the index file for sharded models.
31+
index_file_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json")
32+
# open the index file.
33+
with open(index_file_path, 'r') as f:
34+
index_dict = json.load(f)
35+
model_files = set()
36+
for key in index_dict["weight_map"].keys():
37+
model_files.add(index_dict["weight_map"][key])
38+
39+
model_files = list(model_files)
40+
tensors = {}
41+
for model_file in model_files:
42+
ckpt_shard_path = hf_hub_download(
43+
pretrained_model_name_or_path, subfolder="transformer", filename=model_file
44+
)
45+
# now get all the filenames for the model that need downloading
46+
max_logging.log(f"Load and port Wan 2.1 transformer on {device}")
47+
48+
if ckpt_shard_path is not None:
49+
with safe_open(ckpt_shard_path, framework="pt") as f:
50+
for k in f.keys():
51+
tensors[k] = torch2jax(f.get_tensor(k))
52+
flax_state_dict = {}
53+
cpu = jax.local_devices(backend="cpu")[0]
54+
flattened_dict = flatten_dict(eval_shapes)
55+
# turn all block numbers to strings just for matching weights.
56+
# Later they will be turned back to ints.
57+
random_flax_state_dict = {}
58+
for key in flattened_dict:
59+
string_tuple = tuple([str(item) for item in key])
60+
random_flax_state_dict[string_tuple] = flattened_dict[key]
61+
del flattened_dict
62+
for pt_key, tensor in tensors.items():
63+
renamed_pt_key = rename_key(pt_key)
64+
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
65+
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
66+
renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
67+
renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
68+
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
69+
pt_tuple_key = tuple(renamed_pt_key.split("."))
70+
71+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
72+
flax_key = rename_for_nnx(flax_key)
73+
flax_key = _tuple_str_to_int(flax_key)
74+
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
75+
validate_flax_state_dict(eval_shapes, flax_state_dict)
76+
flax_state_dict = unflatten_dict(flax_state_dict)
77+
del tensors
78+
jax.clear_caches()
79+
return flax_state_dict
1980

2081
def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
2182
device = jax.devices(device)[0]

0 commit comments

Comments
 (0)