import json
import jax
import jax.numpy as jnp
from maxdiffusion import max_logging
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from flax.traverse_util import unflatten_dict, flatten_dict
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)


@@ -16,6 +17,66 @@ def _tuple_str_to_int(in_tuple):
1617 out_list .append (item )
1718 return tuple (out_list )
1819
def rename_for_nnx(key):
  """Map a flattened parameter key to its flax.nnx name.

  nnx normalization layers store their weight under ``scale``, so for any
  query/key-norm parameter the final path element is swapped for "scale".
  All other keys pass through unchanged.
  """
  if any(part in key for part in ("norm_k", "norm_q")):
    return key[:-1] + ("scale",)
  return key
25+
def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Download the Wan 2.1 transformer PyTorch checkpoint and port it to a flax state dict.

  Args:
    pretrained_model_name_or_path: Hugging Face Hub repo id holding the checkpoint.
    eval_shapes: flax model shape tree used to match and validate ported weights.
    device: jax platform name (e.g. "cpu", "tpu") used while porting.
    hf_download: must be True; only Hub downloads are supported.

  Returns:
    A nested (unflattened) flax state dict with all tensors placed on host CPU.

  Raises:
    NotImplementedError: if ``hf_download`` is False.
  """
  device = jax.devices(device)[0]
  with jax.default_device(device):
    if not hf_download:
      # Previously this path fell through and failed later with a NameError on
      # `tensors`; fail fast with a clear message instead.
      raise NotImplementedError("load_wan_transformer currently requires hf_download=True")

    max_logging.log(f"Load and port Wan 2.1 transformer on {device}")

    # The checkpoint is sharded: the index file maps each weight name to the
    # shard file that contains it.
    index_file_path = hf_hub_download(
        pretrained_model_name_or_path,
        subfolder="transformer",
        filename="diffusion_pytorch_model.safetensors.index.json",
    )
    with open(index_file_path, "r") as f:
      index_dict = json.load(f)
    # Unique shard filenames; sorted for a deterministic download/load order.
    model_files = sorted(set(index_dict["weight_map"].values()))

    # Download every shard and collect all tensors (converted to jax) by name.
    tensors = {}
    for model_file in model_files:
      ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file)
      if ckpt_shard_path is not None:
        with safe_open(ckpt_shard_path, framework="pt") as f:
          for k in f.keys():
            tensors[k] = torch2jax(f.get_tensor(k))

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    # Stringify every key part (block numbers are ints) purely so PyTorch keys
    # can be matched against flax keys; they are turned back into ints below.
    random_flax_state_dict = {
        tuple(str(item) for item in key): value for key, value in flatten_dict(eval_shapes).items()
    }
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
      # Torch -> flax module naming for the Wan transformer blocks.
      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
      renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
      renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
      renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
      pt_tuple_key = tuple(renamed_pt_key.split("."))

      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Keep ported weights in host memory until the caller shards them.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    del tensors
    jax.clear_caches()
  return flax_state_dict
1980
2081def load_wan_vae (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
2182 device = jax .devices (device )[0 ]