11import json
2+ import torch
23import jax
34import jax .numpy as jnp
45from maxdiffusion import max_logging
78from flax .traverse_util import unflatten_dict , flatten_dict
89from ..modeling_flax_pytorch_utils import (rename_key , rename_key_and_reshape_tensor , torch2jax , validate_flax_state_dict )
910
# Model identifier that selects the CausVid loading path: load_wan_transformer
# compares the requested model name against this value, and
# load_causvid_transformer passes it to hf_hub_download as the repo to fetch
# "causal_model.pt" from.
11+ CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
12+
1013
1114def _tuple_str_to_int (in_tuple ):
1215 out_list = []
@@ -25,7 +28,68 @@ def rename_for_nnx(key):
2528 return new_key
2629
2730
# Ordered (old, new) substring rewrites applied to every checkpoint key after
# rename_key() has run. Order is significant: "head.modulation" must be
# rewritten before the generic ".modulation" rule, and the "blocks_" rule has
# to run before the attention rules that rely on "." separators.
_CAUSVID_KEY_REPLACEMENTS = (
    ("head.modulation", "scale_shift_table"),
    ("head.head", "proj_out"),
    ("text_embedding_0", "condition_embedder.text_embedder.linear_1"),
    ("text_embedding_2", "condition_embedder.text_embedder.linear_2"),
    ("time_embedding_0", "condition_embedder.time_embedder.linear_1"),
    ("time_embedding_2", "condition_embedder.time_embedder.linear_2"),
    ("time_projection_1", "condition_embedder.time_proj"),
    ("blocks_", "blocks."),
    ("self_attn", "attn1"),
    ("cross_attn", "attn2"),
    (".q.", ".query."),
    (".k.", ".key."),
    (".v.", ".value."),
    (".o.", ".proj_attn."),
    ("ffn_0", "ffn.act_fn.proj"),
    ("ffn_2", "ffn.proj_out"),
    (".modulation", ".scale_shift_table"),
    ("norm3", "norm2.layer_norm"),
)


def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
    """Load the CausVid ``causal_model.pt`` checkpoint as a nested Flax state dict.

    The PyTorch state dict is read from disk, every tensor is converted with
    ``torch2jax``, its key is renamed to the layout described by
    ``eval_shapes``, and the resulting array is placed on the first local CPU
    device. The flat result is validated against ``eval_shapes`` and then
    unflattened.

    Args:
      pretrained_model_name_or_path: With ``hf_download=True``, the repo id
        handed to ``hf_hub_download``. With ``hf_download=False``, a local
        directory containing ``causal_model.pt`` or the path of the checkpoint
        file itself.
      eval_shapes: Nested dict describing the expected parameters; it is
        flattened to match checkpoint keys and used for validation.
      device: Backend name passed to ``jax.devices``; the first device returned
        is used as the default device while loading.
      hf_download: Whether to fetch the checkpoint through ``hf_hub_download``.

    Returns:
      The nested dict produced by ``unflatten_dict`` over the converted weights.

    Raises:
      FileNotFoundError: If ``hf_download`` is False and no checkpoint exists at
        the resolved local path (raised by ``torch.load``).
    """
    import os  # function-scope import keeps the file's import block untouched

    device = jax.devices(device)[0]
    with jax.default_device(device):
        if hf_download:
            ckpt_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
        elif os.path.isdir(pretrained_model_name_or_path):
            # Previously this case left the checkpoint path unassigned.
            ckpt_path = os.path.join(pretrained_model_name_or_path, "causal_model.pt")
        else:
            ckpt_path = pretrained_model_name_or_path

        # map_location="cpu" lets a checkpoint saved from an accelerator be
        # deserialized on hosts that do not expose that device to torch.
        # NOTE(review): torch.load unpickles the file; if the checkpoint holds
        # only tensors, consider weights_only=True for downloaded data.
        loaded_state_dict = torch.load(ckpt_path, map_location="cpu")

        flax_state_dict = {}
        cpu = jax.local_devices(backend="cpu")[0]
        flattened_dict = flatten_dict(eval_shapes)
        # Turn all block numbers to strings just for matching weights.
        # Later they will be turned back to ints.
        random_flax_state_dict = {
            tuple(str(item) for item in key): value for key, value in flattened_dict.items()
        }

        for pt_key, tensor in loaded_state_dict.items():
            tensor = torch2jax(tensor)
            renamed_pt_key = rename_key(pt_key)
            for old, new in _CAUSVID_KEY_REPLACEMENTS:
                renamed_pt_key = renamed_pt_key.replace(old, new)

            pt_tuple_key = tuple(renamed_pt_key.split("."))

            flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
            flax_key = rename_for_nnx(flax_key)
            flax_key = _tuple_str_to_int(flax_key)
            flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

        validate_flax_state_dict(eval_shapes, flax_state_dict)
        flax_state_dict = unflatten_dict(flax_state_dict)
        # Drop the torch copy of the weights before clearing JAX caches.
        del loaded_state_dict
        jax.clear_caches()
        return flax_state_dict
82+
83+
def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
    """Load Wan transformer weights, choosing the loader from the model name.

    When ``pretrained_model_name_or_path`` equals
    ``CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH`` the call is forwarded to
    ``load_causvid_transformer``; every other value is forwarded to
    ``load_base_wan_transformer``. All four arguments are passed through
    unchanged and the selected loader's return value is returned as-is.
    """
    is_causvid = pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
    # Only the selected loader's name is looked up, as in the if/else form.
    loader = load_causvid_transformer if is_causvid else load_base_wan_transformer
    return loader(pretrained_model_name_or_path, eval_shapes, device, hf_download)
90+
91+
92+ def load_base_wan_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
2993 device = jax .devices (device )[0 ]
3094 with jax .default_device (device ):
3195 if hf_download :
0 commit comments