|
8 | 8 | from safetensors import safe_open |
9 | 9 | from flax.traverse_util import unflatten_dict, flatten_dict |
10 | 10 | from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) |
| 11 | +from ...common_types import WAN_MODEL |
11 | 12 |
|
12 | 13 | CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid" |
13 | 14 | WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX" |
@@ -82,7 +83,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di |
82 | 83 |
|
83 | 84 | pt_tuple_key = tuple(renamed_pt_key.split(".")) |
84 | 85 |
|
85 | | - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) |
| 86 | + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) |
86 | 87 | flax_key = rename_for_nnx(flax_key) |
87 | 88 | flax_key = _tuple_str_to_int(flax_key) |
88 | 89 | flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) |
@@ -117,7 +118,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di |
117 | 118 |
|
118 | 119 | pt_tuple_key = tuple(renamed_pt_key.split(".")) |
119 | 120 |
|
120 | | - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) |
| 121 | + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) |
121 | 122 | flax_key = rename_for_nnx(flax_key) |
122 | 123 | flax_key = _tuple_str_to_int(flax_key) |
123 | 124 | flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) |
@@ -196,9 +197,20 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d |
196 | 197 | renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") |
197 | 198 | pt_tuple_key = tuple(renamed_pt_key.split(".")) |
198 | 199 |
|
199 | | - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) |
| 200 | + if "blocks" in pt_tuple_key: |
| 201 | + new_key = ("blocks",) + pt_tuple_key[2:] |
| 202 | + block_index = int(pt_tuple_key[1]) |
| 203 | + pt_tuple_key = new_key |
| 204 | + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) |
200 | 205 | flax_key = rename_for_nnx(flax_key) |
201 | 206 | flax_key = _tuple_str_to_int(flax_key) |
| 207 | + |
| 208 | + if "blocks" in flax_key: |
| 209 | + if flax_key in flax_state_dict: |
| 210 | + new_tensor = flax_state_dict[flax_key] |
| 211 | + else: |
| 212 | + new_tensor = jnp.zeros((40,) + flax_tensor.shape) |
| 213 | + flax_tensor = new_tensor.at[block_index].set(flax_tensor) |
202 | 214 | flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) |
203 | 215 | validate_flax_state_dict(eval_shapes, flax_state_dict) |
204 | 216 | flax_state_dict = unflatten_dict(flax_state_dict) |
|
0 commit comments