86 changes: 67 additions & 19 deletions src/maxdiffusion/models/wan/wan_utils.py
@@ -9,6 +9,7 @@
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)

CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"


def _tuple_str_to_int(in_tuple):
@@ -28,6 +29,69 @@ def rename_for_nnx(key):
  return new_key


def rename_for_custom_transformer(key):
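  """Remap a checkpoint key from the "model.diffusion_model.*" naming scheme
  to this repo's Flax transformer layout (summary inferred from the
  replacements below)."""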
  renamed_pt_key = key.replace("model.diffusion_model.", "")

  renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
  renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
  renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
  renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
  renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
  renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
  renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")

  renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
  renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
  renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
  renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
  renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
  renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
  renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
  renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
  renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
  renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
  renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")

  return renamed_pt_key


def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
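  """Load the single-file FusionX safetensors checkpoint from the Hub and
  convert it into a Flax state dict matching eval_shapes."""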
  device = jax.devices(device)[0]
  with jax.default_device(device):
    if hf_download:
      ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors")
    tensors = {}
    with safe_open(ckpt_shard_path, framework="pt") as f:
      for k in f.keys():
        tensors[k] = torch2jax(f.get_tensor(k))

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # Turn all block numbers into strings just for matching weights.
    # Later they will be turned back into ints.
    random_flax_state_dict = {}
    for key in flattened_dict:
      string_tuple = tuple([str(item) for item in key])
      random_flax_state_dict[string_tuple] = flattened_dict[key]
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)

      renamed_pt_key = rename_for_custom_transformer(renamed_pt_key)

      pt_tuple_key = tuple(renamed_pt_key.split("."))

      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    del tensors
    jax.clear_caches()
    return flax_state_dict


def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
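  """Load the CausVid transformer checkpoint and convert it into this repo's
  Flax state-dict layout (now sharing the key remapping with the FusionX
  loader above)."""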
  device = jax.devices(device)[0]
  with jax.default_device(device):
@@ -48,25 +112,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
    for pt_key, tensor in loaded_state_dict.items():
      tensor = torch2jax(tensor)
      renamed_pt_key = rename_key(pt_key)
renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")

renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
      renamed_pt_key = rename_for_custom_transformer(renamed_pt_key)

      pt_tuple_key = tuple(renamed_pt_key.split("."))

@@ -85,6 +131,8 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,

  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
  elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
  else:
    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)

7 changes: 5 additions & 2 deletions src/maxdiffusion/pyconfig.py
@@ -25,7 +25,7 @@
import yaml
from . import max_logging
from . import max_utils
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH


def string_to_bool(s: str) -> bool:
@@ -118,7 +118,10 @@ def wan_init(raw_keys):
  transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
  if transformer_pretrained_model_name_or_path == "":
    raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
  elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
  elif (
      transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
      or transformer_pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH
  ):
    # Set correct parameters for CausVid and FusionX in case of user error.
    raw_keys["guidance_scale"] = 1.0
    num_inference_steps = raw_keys["num_inference_steps"]
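
Taken together, the two files route the new checkpoint end to end: pyconfig recognizes the FusionX path and forces guidance_scale to 1.0, and load_wan_transformer dispatches to the new loader. Below is a minimal, editorial sketch of exercising the new branch directly; eval_shapes and how it is constructed are assumptions for illustration, not part of this PR.

# Editorial sketch, not part of the diff: exercise the new FusionX branch.
from maxdiffusion.models.wan.wan_utils import (
    WAN_21_FUSION_X_MODEL_NAME_OR_PATH,
    load_wan_transformer,
)

# Assumption: `eval_shapes` is the abstract parameter-shape tree of an
# initialized Flax WAN 2.1 transformer (the structure that
# validate_flax_state_dict checks the converted weights against).
params = load_wan_transformer(
    WAN_21_FUSION_X_MODEL_NAME_OR_PATH,  # matches the elif added in load_wan_transformer
    eval_shapes=eval_shapes,
    device="cpu",      # device string resolved via jax.devices(device)[0]
    hf_download=True,  # fetches Wan14BT2VFusioniX_fp16_.safetensors from the Hub
)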