diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 9ef1b72d6..21796738a 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -29,6 +29,9 @@ log_period: 100
 
 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
+# Overrides the transformer from pretrained_model_name_or_path
+wan_transformer_pretrained_model_name_or_path: ''
+
 unet_checkpoint: ''
 revision: ''
 
 # This will convert the weights to this dtype.
diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py
index f84346735..9c9ae2c67 100644
--- a/src/maxdiffusion/models/wan/wan_utils.py
+++ b/src/maxdiffusion/models/wan/wan_utils.py
@@ -1,4 +1,5 @@
 import json
+import torch
 import jax
 import jax.numpy as jnp
 from maxdiffusion import max_logging
@@ -7,6 +8,8 @@ from flax.traverse_util import unflatten_dict, flatten_dict
 from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
 
+CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
+
 
 def _tuple_str_to_int(in_tuple):
   out_list = []
   for item in in_tuple:
@@ -25,7 +28,68 @@ def rename_for_nnx(key):
   return new_key
 
 
+def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+  """Load the CausVid (lightx2v) PyTorch transformer checkpoint and remap its keys to the flax/nnx layout."""
+  device = jax.devices(device)[0]
+  with jax.default_device(device):
+    if hf_download:
+      ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
+      # map_location keeps torch from restoring tensors onto the GPU the checkpoint was saved from.
+      loaded_state_dict = torch.load(ckpt_shard_path, map_location="cpu")
+
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+    flattened_dict = flatten_dict(eval_shapes)
+    # turn all block numbers to strings just for matching weights.
+    # Later they will be turned back to ints.
+    random_flax_state_dict = {}
+    for key in flattened_dict:
+      string_tuple = tuple([str(item) for item in key])
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    for pt_key, tensor in loaded_state_dict.items():
+      tensor = torch2jax(tensor)
+      renamed_pt_key = rename_key(pt_key)
+      renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
+      renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
+      renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
+      renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
+      renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
+      renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")
+
+      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
+      renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
+      renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
+      renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
+      renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
+      renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
+      renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
+      renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
+      renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
+      renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
+
+      pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
+      flax_key = rename_for_nnx(flax_key)
+      flax_key = _tuple_str_to_int(flax_key)
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+    validate_flax_state_dict(eval_shapes, flax_state_dict)
+    flax_state_dict = unflatten_dict(flax_state_dict)
+    jax.clear_caches()
+    return flax_state_dict
+
+
 def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+  # CausVid ships a raw PyTorch checkpoint with its own key names; everything else is diffusers layout.
+  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+  else:
+    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+
+
+def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
   with jax.default_device(device):
     if hf_download:
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index db4b25fb2..d01aea3fc 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -95,7 +95,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
   # This helps with loading sharded weights directly into the accelerators without fist copying them
   # all to one device and then distributing them, thus using low HBM memory.
-  params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
+  params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu")
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
   for path, val in flax.traverse_util.flatten_dict(params).items():
     sharding = logical_state_sharding[path].value
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 67437ba0b..dbc48cc54 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -25,6 +25,9 @@ import yaml
 
 from . import max_logging
 from . import max_utils
+# NOTE(review): importing from wan_utils pulls torch/jax/flax into config parsing;
+# consider moving this constant to a lightweight module.
+from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
 
 
 def string_to_bool(s: str) -> bool:
@@ -102,6 +105,7 @@ def __init__(self, argv: list[str], **kwargs):
       jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
 
     _HyperParameters.user_init(raw_keys)
+    _HyperParameters.wan_init(raw_keys)
     self.keys = raw_keys
     for k in sorted(raw_keys.keys()):
       max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -110,6 +114,26 @@ def _load_kwargs(self, argv: list[str]):
     args_dict = dict(a.split("=", 1) for a in argv[2:])
     return args_dict
 
+  @staticmethod
+  def wan_init(raw_keys):
+    """Wan-specific config fixups; must run after user_init so values are already typed."""
+    if "wan_transformer_pretrained_model_name_or_path" in raw_keys:
+      transformer_pretrained_model_name_or_path = raw_keys["wan_transformer_pretrained_model_name_or_path"]
+      if transformer_pretrained_model_name_or_path == "":
+        raw_keys["wan_transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
+      elif transformer_pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
+        # Set correct parameters for CausVid in case of user error.
+        raw_keys["guidance_scale"] = 1.0
+        num_inference_steps = raw_keys["num_inference_steps"]
+        if num_inference_steps > 10:
+          max_logging.log(
+            f"Warning: Try setting num_inference_steps to 10 or fewer steps when using CausVid, currently you are setting {num_inference_steps} steps."
+          )
+      else:
+        # NOTE(review): this rejects any other override, including diffusers-layout
+        # checkpoints that load_base_wan_transformer could handle — confirm intent.
+        raise ValueError(f"{transformer_pretrained_model_name_or_path} transformer model is not supported for Wan 2.1")
+
 @staticmethod
 def user_init(raw_keys):
   """Transformations between the config data and configs used at runtime"""