Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ log_period: 100

pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'

# If set, overrides the transformer loaded from pretrained_model_name_or_path.
wan_transformer_pretrained_model_name_or_path: ''

unet_checkpoint: ''
revision: ''
# This will convert the weights to this dtype.
Expand Down
64 changes: 64 additions & 0 deletions src/maxdiffusion/models/wan/wan_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import torch
import jax
import jax.numpy as jnp
from maxdiffusion import max_logging
Expand All @@ -7,6 +8,8 @@
from flax.traverse_util import unflatten_dict, flatten_dict
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)

CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"


def _tuple_str_to_int(in_tuple):
out_list = []
Expand All @@ -25,7 +28,68 @@ def rename_for_nnx(key):
return new_key


def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Load CausVid (lightx2v) PyTorch transformer weights and convert them to a flax state dict.

  Downloads ``causal_model.pt`` from the HF Hub, renames each PyTorch key to the
  maxdiffusion/nnx naming scheme, reshapes tensors to match ``eval_shapes``, and
  returns the nested flax state dict with all tensors placed on CPU.

  Args:
    pretrained_model_name_or_path: HF Hub repo id containing ``causal_model.pt``.
    eval_shapes: Flattened-able tree of expected parameter shapes used for
      key matching, reshaping, and final validation.
    device: JAX platform name (e.g. ``"cpu"``) used as the default device while
      converting — TODO confirm intended platform vs. the final CPU placement.
    hf_download: Must be True; only the HF Hub download path is implemented.

  Returns:
    Nested (unflattened) flax state dict keyed like ``eval_shapes``.

  Raises:
    ValueError: if ``hf_download`` is False. The original code only defined the
      checkpoint inside the download branch, which would have raised a confusing
      NameError later; fail explicitly up front instead.
  """
  if not hf_download:
    raise ValueError("load_causvid_transformer currently requires hf_download=True.")
  device = jax.devices(device)[0]
  with jax.default_device(device):
    ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
    # map_location="cpu" keeps loading deterministic and avoids materializing
    # tensors on whatever device the checkpoint was saved from.
    loaded_state_dict = torch.load(ckpt_shard_path, map_location="cpu")

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)
    # turn all block numbers to strings just for matching weights.
    # Later they will be turned back to ints.
    random_flax_state_dict = {}
    for key in flattened_dict:
      string_tuple = tuple([str(item) for item in key])
      random_flax_state_dict[string_tuple] = flattened_dict[key]
    for pt_key, tensor in loaded_state_dict.items():
      tensor = torch2jax(tensor)
      renamed_pt_key = rename_key(pt_key)
      # Map CausVid checkpoint names onto the Diffusers/maxdiffusion layout.
      renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
      renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
      renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
      renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
      renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
      renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
      renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")

      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
      renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
      renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
      renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
      renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
      renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
      renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
      renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
      renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
      renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
      renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")

      pt_tuple_key = tuple(renamed_pt_key.split("."))

      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
      flax_key = rename_for_nnx(flax_key)
      flax_key = _tuple_str_to_int(flax_key)
      # Stage converted weights on CPU; sharding/placement happens in the caller.
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
    validate_flax_state_dict(eval_shapes, flax_state_dict)
    flax_state_dict = unflatten_dict(flax_state_dict)
    jax.clear_caches()
  return flax_state_dict


def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Dispatch transformer loading to the CausVid or base-Wan converter.

  The CausVid checkpoint uses a different file layout and key naming than the
  Diffusers-format base model, so it needs its own conversion path; every other
  repo id is handled by the base loader.
  """
  loader = (
      load_causvid_transformer
      if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
      else load_base_wan_transformer
  )
  return loader(pretrained_model_name_or_path, eval_shapes, device, hf_download)


def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
device = jax.devices(device)[0]
with jax.default_device(device):
if hf_download:
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
# This helps with loading sharded weights directly into the accelerators without first copying them
# all to one device and then distributing them, thus using low HBM memory.
params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding[path].value
Expand Down
19 changes: 19 additions & 0 deletions src/maxdiffusion/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml
from . import max_logging
from . import max_utils
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH


def string_to_bool(s: str) -> bool:
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])

_HyperParameters.user_init(raw_keys)
_HyperParameters.wan_init(raw_keys)
self.keys = raw_keys
for k in sorted(raw_keys.keys()):
max_logging.log(f"Config param {k}: {raw_keys[k]}")
Expand All @@ -110,6 +112,23 @@ def _load_kwargs(self, argv: list[str]):
args_dict = dict(a.split("=", 1) for a in argv[2:])
return args_dict

@staticmethod
def wan_init(raw_keys):
  """Normalize Wan 2.1-specific config keys in place.

  - Defaults ``wan_transformer_pretrained_model_name_or_path`` to the base
    ``pretrained_model_name_or_path`` when left empty.
  - For the CausVid transformer, forces ``guidance_scale`` to 1.0 and warns
    when ``num_inference_steps`` looks too high for a distilled model.

  Args:
    raw_keys: mutable config dict; modified in place.

  Raises:
    ValueError: if a non-empty transformer override is neither empty nor the
      supported CausVid checkpoint.
  """
  if "wan_transformer_pretrained_model_name_or_path" not in raw_keys:
    return
  transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
  if transformer_pretrained_model_name_or_path == "":
    # No override requested: fall back to the base model checkpoint.
    raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
  elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
    # Set correct parameters for CausVid in case of user error.
    raw_keys["guidance_scale"] = 1.0
    num_inference_steps = raw_keys["num_inference_steps"]
    if num_inference_steps > 10:
      # Bug fix: the original message recommended "less than 8 steps" while the
      # check only fired above 10 — the message now matches the threshold.
      max_logging.log(
          f"Warning: Try setting num_inference_steps to 10 or fewer steps when using CausVid, currently you are setting {num_inference_steps} steps."
      )
  else:
    raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")

@staticmethod
def user_init(raw_keys):
"""Transformations between the config data and configs used at runtime"""
Expand Down
Loading