11import json
2+ import torch
23import jax
34import jax .numpy as jnp
45from maxdiffusion import max_logging
@@ -24,8 +25,66 @@ def rename_for_nnx(key):
2425 new_key = key [:- 1 ] + ("scale" ,)
2526 return new_key
2627
28+ def load_causvid_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
29+ device = jax .devices (device )[0 ]
30+ with jax .default_device (device ):
31+ if hf_download :
32+ ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , filename = "causal_model.pt" )
33+ loaded_state_dict = torch .load (ckpt_shard_path )
34+
35+ tensors = {}
36+ flax_state_dict = {}
37+ cpu = jax .local_devices (backend = "cpu" )[0 ]
38+ flattened_dict = flatten_dict (eval_shapes )
39+ # turn all block numbers to strings just for matching weights.
40+ # Later they will be turned back to ints.
41+ random_flax_state_dict = {}
42+ for key in flattened_dict :
43+ string_tuple = tuple ([str (item ) for item in key ])
44+ random_flax_state_dict [string_tuple ] = flattened_dict [key ]
45+ for pt_key , tensor in loaded_state_dict .items ():
46+ tensor = torch2jax (tensor )
47+ renamed_pt_key = rename_key (pt_key )
48+ renamed_pt_key = renamed_pt_key .replace ("head.modulation" , "scale_shift_table" )
49+ renamed_pt_key = renamed_pt_key .replace ("head.head" , "proj_out" )
50+ renamed_pt_key = renamed_pt_key .replace ("text_embedding_0" , "condition_embedder.text_embedder.linear_1" )
51+ renamed_pt_key = renamed_pt_key .replace ("text_embedding_2" , "condition_embedder.text_embedder.linear_2" )
52+ renamed_pt_key = renamed_pt_key .replace ("time_embedding_0" , "condition_embedder.time_embedder.linear_1" )
53+ renamed_pt_key = renamed_pt_key .replace ("time_embedding_2" , "condition_embedder.time_embedder.linear_2" )
54+ renamed_pt_key = renamed_pt_key .replace ("time_projection_1" , "condition_embedder.time_proj" )
55+
56+ renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
57+ renamed_pt_key = renamed_pt_key .replace ("self_attn" , "attn1" )
58+ renamed_pt_key = renamed_pt_key .replace ("cross_attn" , "attn2" )
59+ renamed_pt_key = renamed_pt_key .replace (".q." , ".query." )
60+ renamed_pt_key = renamed_pt_key .replace (".k." , ".key." )
61+ renamed_pt_key = renamed_pt_key .replace (".v." , ".value." )
62+ renamed_pt_key = renamed_pt_key .replace (".o." , ".proj_attn." )
63+ renamed_pt_key = renamed_pt_key .replace ("ffn_0" , "ffn.act_fn.proj" )
64+ renamed_pt_key = renamed_pt_key .replace ("ffn_2" , "ffn.proj_out" )
65+ renamed_pt_key = renamed_pt_key .replace (".modulation" , ".scale_shift_table" )
66+ renamed_pt_key = renamed_pt_key .replace ("norm3" , "norm2.layer_norm" )
67+
68+ pt_tuple_key = tuple (renamed_pt_key .split ("." ))
69+
70+ flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
71+ flax_key = rename_for_nnx (flax_key )
72+ flax_key = _tuple_str_to_int (flax_key )
73+ flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
74+ validate_flax_state_dict (eval_shapes , flax_state_dict )
75+ flax_state_dict = unflatten_dict (flax_state_dict )
76+ del tensors
77+ jax .clear_caches ()
78+ return flax_state_dict
2779
2880def load_wan_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
81+
82+ if "CausVid" in pretrained_model_name_or_path :
83+ return load_causvid_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
84+ else :
85+ return load_base_wan_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
86+
87+ def load_base_wan_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
2988 device = jax .devices (device )[0 ]
3089 with jax .default_device (device ):
3190 if hf_download :
0 commit comments