
Commit 2016d7b

Author: Juan Acevedo (committed)

Commit message: use caus_vid for faster inference.

1 parent: e53ee2b

3 files changed: 67 additions & 2 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ save_config_to_gcs: False
 log_period: 100

 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
+# Overrides the transformer from pretrained_model_name_or_path
+transformer_pretrained_model_name_or_path: 'lightx2v/Wan2.1-T2V-14B-CausVid'

 unet_checkpoint: ''
 revision: ''
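
The new key acts as an optional override: when it is set, the pipeline loads the transformer weights from that repository instead of the transformer bundled with pretrained_model_name_or_path (see the wan_pipeline.py change below). A minimal sketch of that resolution logic, using a hypothetical stand-in config object rather than maxdiffusion's real config class:

from dataclasses import dataclass


@dataclass
class FakeConfig:
  # Hypothetical stand-in for the parsed YAML config; field names mirror base_wan_14b.yml.
  pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
  transformer_pretrained_model_name_or_path: str = "lightx2v/Wan2.1-T2V-14B-CausVid"


def resolve_transformer_path(config) -> str:
  # Prefer the explicit transformer override when it is a non-empty string;
  # an empty value falls back to the transformer inside the base checkpoint.
  return config.transformer_pretrained_model_name_or_path or config.pretrained_model_name_or_path


print(resolve_transformer_path(FakeConfig()))
# -> lightx2v/Wan2.1-T2V-14B-CausVid
print(resolve_transformer_path(FakeConfig(transformer_pretrained_model_name_or_path="")))
# -> Wan-AI/Wan2.1-T2V-14B-Diffusers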

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 59 additions & 0 deletions
@@ -1,4 +1,5 @@
 import json
+import torch
 import jax
 import jax.numpy as jnp
 from maxdiffusion import max_logging
@@ -24,8 +25,66 @@ def rename_for_nnx(key):
     new_key = key[:-1] + ("scale",)
   return new_key

+def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+  device = jax.devices(device)[0]
+  with jax.default_device(device):
+    if hf_download:
+      ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt")
+      loaded_state_dict = torch.load(ckpt_shard_path)
+
+    tensors = {}
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+    flattened_dict = flatten_dict(eval_shapes)
+    # turn all block numbers to strings just for matching weights.
+    # Later they will be turned back to ints.
+    random_flax_state_dict = {}
+    for key in flattened_dict:
+      string_tuple = tuple([str(item) for item in key])
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    for pt_key, tensor in loaded_state_dict.items():
+      tensor = torch2jax(tensor)
+      renamed_pt_key = rename_key(pt_key)
+      renamed_pt_key = renamed_pt_key.replace("head.modulation", "scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("head.head", "proj_out")
+      renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "condition_embedder.text_embedder.linear_1")
+      renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "condition_embedder.text_embedder.linear_2")
+      renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "condition_embedder.time_embedder.linear_1")
+      renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "condition_embedder.time_embedder.linear_2")
+      renamed_pt_key = renamed_pt_key.replace("time_projection_1", "condition_embedder.time_proj")
+
+      renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
+      renamed_pt_key = renamed_pt_key.replace("self_attn", "attn1")
+      renamed_pt_key = renamed_pt_key.replace("cross_attn", "attn2")
+      renamed_pt_key = renamed_pt_key.replace(".q.", ".query.")
+      renamed_pt_key = renamed_pt_key.replace(".k.", ".key.")
+      renamed_pt_key = renamed_pt_key.replace(".v.", ".value.")
+      renamed_pt_key = renamed_pt_key.replace(".o.", ".proj_attn.")
+      renamed_pt_key = renamed_pt_key.replace("ffn_0", "ffn.act_fn.proj")
+      renamed_pt_key = renamed_pt_key.replace("ffn_2", "ffn.proj_out")
+      renamed_pt_key = renamed_pt_key.replace(".modulation", ".scale_shift_table")
+      renamed_pt_key = renamed_pt_key.replace("norm3", "norm2.layer_norm")
+
+      pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
+      flax_key = rename_for_nnx(flax_key)
+      flax_key = _tuple_str_to_int(flax_key)
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+    validate_flax_state_dict(eval_shapes, flax_state_dict)
+    flax_state_dict = unflatten_dict(flax_state_dict)
+    del tensors
+    jax.clear_caches()
+    return flax_state_dict

 def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
+  if "CausVid" in pretrained_model_name_or_path:
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+  else:
+    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
+
+
+def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
   with jax.default_device(device):
     if hf_download:
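
The bulk of load_causvid_transformer is the string-substitution table that maps CausVid's PyTorch parameter names onto maxdiffusion's WanModel parameter tree (the real code additionally converts tensors with torch2jax and reshapes them via rename_key_and_reshape_tensor). A standalone sketch of just the renaming step, applied to a few hypothetical checkpoint keys for illustration:

# Illustrative only: applies the same substitutions as load_causvid_transformer
# to show how CausVid-style keys map onto WanModel parameter names.
_RENAMES = [
    ("head.modulation", "scale_shift_table"),
    ("head.head", "proj_out"),
    ("self_attn", "attn1"),
    ("cross_attn", "attn2"),
    (".q.", ".query."),
    (".k.", ".key."),
    (".v.", ".value."),
    (".o.", ".proj_attn."),
    ("ffn_0", "ffn.act_fn.proj"),
    ("ffn_2", "ffn.proj_out"),
    (".modulation", ".scale_shift_table"),
    ("norm3", "norm2.layer_norm"),
]


def rename_causvid_key(pt_key: str) -> str:
  # The sample keys below are hypothetical; the real checkpoint layout may differ.
  for old, new in _RENAMES:
    pt_key = pt_key.replace(old, new)
  return pt_key


print(rename_causvid_key("blocks.0.self_attn.q.weight"))  # blocks.0.attn1.query.weight
print(rename_causvid_key("blocks.0.cross_attn.o.bias"))   # blocks.0.attn2.proj_attn.bias
print(rename_causvid_key("blocks.0.ffn_0.weight"))        # blocks.0.ffn.act_fn.proj.weight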

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 2 deletions
@@ -71,7 +71,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
       return wan_transformer

     # 1. Load config.
-    wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
+    wan_config = WanModel.load_config(
+        config.pretrained_model_name_or_path,
+        subfolder="transformer")
     wan_config["mesh"] = mesh
     wan_config["dtype"] = config.activations_dtype
     wan_config["weights_dtype"] = config.weights_dtype
@@ -95,7 +97,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
     # This helps with loading sharded weights directly into the accelerators without first copying them
     # all to one device and then distributing them, thus using low HBM memory.
-    params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
+    params = load_wan_transformer(
+        config.transformer_pretrained_model_name_or_path or config.pretrained_model_name_or_path,
+        params, "cpu")
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
     for path, val in flax.traverse_util.flatten_dict(params).items():
       sharding = logical_state_sharding[path].value
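
The comment in that hunk describes the low-HBM loading pattern: weights are first materialized on host (CPU) memory and then placed leaf by leaf onto the accelerators with their target shardings. Below is a minimal JAX sketch of that pattern, not maxdiffusion code; the mesh axis name and the parameter shape are made up for illustration:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over whatever devices are available ("data" is an illustrative axis name).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Pretend this is one flattened parameter loaded from the checkpoint onto host memory.
cpu = jax.local_devices(backend="cpu")[0]
host_param = jax.device_put(np.ones((jax.device_count(), 1024), dtype=np.float32), device=cpu)

# Place it on the mesh according to its sharding; transfers happen one leaf at a time,
# so no single accelerator ever has to hold the full set of unsharded weights.
sharded_param = jax.device_put(host_param, NamedSharding(mesh, P("data", None)))
print(sharded_param.sharding)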
