Skip to content

Commit b9019f8

Browse files
author
Juan Acevedo
committed
adds causvid
1 parent 2016d7b commit b9019f8

4 files changed

Lines changed: 30 additions & 10 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ save_config_to_gcs: False
2828
log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
31+
3132
# Overrides the transformer from pretrained_model_name_or_path
32-
transformer_pretrained_model_name_or_path: 'lightx2v/Wan2.1-T2V-14B-CausVid'
33+
transformer_pretrained_model_name_or_path: ''
3334

3435
unet_checkpoint: ''
3536
revision: ''

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from flax.traverse_util import unflatten_dict, flatten_dict
99
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
1010

11+
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
12+
1113

1214
def _tuple_str_to_int(in_tuple):
1315
out_list = []
@@ -25,13 +27,14 @@ def rename_for_nnx(key):
2527
new_key = key[:-1] + ("scale",)
2628
return new_key
2729

30+
2831
def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
2932
device = jax.devices(device)[0]
3033
with jax.default_device(device):
3134
if hf_download:
3235
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
3336
loaded_state_dict = torch.load(ckpt_shard_path)
34-
37+
3538
tensors = {}
3639
flax_state_dict = {}
3740
cpu = jax.local_devices(backend="cpu")[0]
@@ -77,13 +80,15 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
7780
jax.clear_caches()
7881
return flax_state_dict
7982

83+
8084
def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
  """Load Wan transformer weights, dispatching on the checkpoint identifier.

  The CausVid distilled checkpoint ships in its own format and needs a
  dedicated conversion path; every other identifier goes through the base
  Wan loader.
  """
  is_causvid = pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
  loader = load_causvid_transformer if is_causvid else load_base_wan_transformer
  return loader(pretrained_model_name_or_path, eval_shapes, device, hf_download)
8690

91+
8792
def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
8893
device = jax.devices(device)[0]
8994
with jax.default_device(device):

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
7171
return wan_transformer
7272

7373
# 1. Load config.
74-
wan_config = WanModel.load_config(
75-
config.pretrained_model_name_or_path,
76-
subfolder="transformer")
74+
wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
7775
wan_config["mesh"] = mesh
7876
wan_config["dtype"] = config.activations_dtype
7977
wan_config["weights_dtype"] = config.weights_dtype
@@ -97,9 +95,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
9795
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
9896
# This helps with loading sharded weights directly into the accelerators without fist copying them
9997
# all to one device and then distributing them, thus using low HBM memory.
100-
params = load_wan_transformer(
101-
config.transformer_pretrained_model_name_or_path or config.pretrained_model_name_or_path,
102-
params, "cpu")
98+
params = load_wan_transformer(config.transformer_pretrained_model_name_or_path, params, "cpu")
10399
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
104100
for path, val in flax.traverse_util.flatten_dict(params).items():
105101
sharding = logical_state_sharding[path].value

src/maxdiffusion/pyconfig.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import yaml
2626
from . import max_logging
2727
from . import max_utils
28+
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH
2829

2930

3031
def string_to_bool(s: str) -> bool:
@@ -102,6 +103,7 @@ def __init__(self, argv: list[str], **kwargs):
102103
jax.config.update("jax_compilation_cache_dir", raw_keys["jax_cache_dir"])
103104

104105
_HyperParameters.user_init(raw_keys)
106+
_HyperParameters.wan_init(raw_keys)
105107
self.keys = raw_keys
106108
for k in sorted(raw_keys.keys()):
107109
max_logging.log(f"Config param {k}: {raw_keys[k]}")
@@ -110,6 +112,22 @@ def _load_kwargs(self, argv: list[str]):
110112
args_dict = dict(a.split("=", 1) for a in argv[2:])
111113
return args_dict
112114

115+
@staticmethod
def wan_init(raw_keys):
  """Validate and normalize the Wan 2.1 transformer settings in the raw config.

  - An empty ``transformer_pretrained_model_name_or_path`` falls back to
    ``pretrained_model_name_or_path``.
  - The CausVid distilled transformer forces ``guidance_scale`` to 1.0
    (logging the override) and warns when ``num_inference_steps`` looks too
    high for a few-step distilled model.
  - Any other transformer path is rejected.

  Mutates ``raw_keys`` in place; raises ValueError on an unsupported path.
  """
  transformer_path = raw_keys["transformer_pretrained_model_name_or_path"]
  if transformer_path == "":
    # Default: load the transformer from the main pretrained checkpoint.
    raw_keys["transformer_pretrained_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
  elif transformer_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
    # CausVid is CFG-distilled, so classifier-free guidance must be disabled.
    # Log the override so a user-supplied guidance_scale is not silently ignored.
    if raw_keys.get("guidance_scale") != 1.0:
      max_logging.log("Overriding guidance_scale to 1.0 for the CausVid transformer.")
    raw_keys["guidance_scale"] = 1.0
    num_inference_steps = raw_keys["num_inference_steps"]
    # Threshold and advice kept consistent: warn only above 10 steps.
    if num_inference_steps > 10:
      max_logging.log(
          f"Warning: CausVid is a few-step distilled model; try setting num_inference_steps to 10 or fewer "
          f"(ideally under 8). Currently you are setting {num_inference_steps} steps."
      )
  else:
    raise ValueError(f"{transformer_path} transformer model is not supported for Wan 2.1")
130+
113131
@staticmethod
114132
def user_init(raw_keys):
115133
"""Transformations between the config data and configs used at runtime"""

0 commit comments

Comments
 (0)