Skip to content

Commit 40d423d

Browse files
e2e wan vae with weights loading. Still not fully working.
1 parent 089f8ac commit 40d423d

12 files changed

Lines changed: 527 additions & 57 deletions

src/maxdiffusion/configuration_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
464464
# remove flax internal keys
465465
if hasattr(cls, "_flax_internal_args"):
466466
for arg in cls._flax_internal_args:
467-
expected_keys.remove(arg)
467+
if arg in expected_keys:
468+
expected_keys.remove(arg)
468469

469470
# 2. Remove attributes that cannot be expected from expected config attributes
470471
# remove keys to be ignored

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from absl import app
2727
from transformers import AutoTokenizer, UMT5EncoderModel
2828
from maxdiffusion import pyconfig, max_logging
29+
from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWan
2930
from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel
3031
from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline
3132

src/maxdiffusion/models/flux/util.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from jax import numpy as jnp
1212
from safetensors import safe_open
1313

14-
from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)
14+
from ..modeling_flax_pytorch_utils import (
15+
rename_key,
16+
rename_key_and_reshape_tensor,
17+
torch2jax
18+
)
1519
from maxdiffusion import max_logging
1620

1721

@@ -32,21 +36,6 @@ class FluxParams:
3236
rngs: Array
3337
param_dtype: DTypeLike
3438

35-
36-
def torch2jax(torch_tensor: torch.Tensor) -> Array:
37-
is_bfloat16 = torch_tensor.dtype == torch.bfloat16
38-
if is_bfloat16:
39-
# upcast the tensor to fp32
40-
torch_tensor = torch_tensor.float()
41-
42-
if torch.device.type != "cpu":
43-
torch_tensor = torch_tensor.to("cpu")
44-
45-
numpy_value = torch_tensor.numpy()
46-
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
47-
return jax_array
48-
49-
5039
@dataclass
5140
class ModelSpec:
5241
params: FluxParams

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,57 @@
1515
""" PyTorch - Flax general utilities."""
1616
import re
1717

18+
import torch
1819
import jax
1920
import jax.numpy as jnp
2021
from flax.linen import Partitioned
2122
from flax.traverse_util import flatten_dict, unflatten_dict
2223
from flax.core.frozen_dict import unfreeze
2324
from jax.random import PRNGKey
24-
25+
from chex import Array
2526
from ..utils import logging
27+
from .. import max_logging
2628

2729

2830
logger = logging.get_logger(__name__)
2931

32+
def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict):
33+
"""
34+
expected_pytree: dict - a pytree that comes from initializing the model.
35+
new_pytree: dict - a pytree that has been created from pytorch weights.
36+
"""
37+
expected_pytree = flatten_dict(expected_pytree)
38+
if len(expected_pytree.keys()) != len(new_pytree.keys()):
39+
set1 = set(expected_pytree.keys())
40+
set2 = set(new_pytree.keys())
41+
missing_keys = set1 ^ set2
42+
max_logging.log(f"missing keys : {missing_keys}")
43+
for key in expected_pytree.keys():
44+
if key in new_pytree.keys():
45+
try:
46+
expected_pytree_shape = expected_pytree[key].shape
47+
except Exception:
48+
expected_pytree_shape = expected_pytree[key].value.shape
49+
if expected_pytree_shape != new_pytree[key].shape:
50+
max_logging.log(f"shape mismatch for {key}")
51+
max_logging.log(
52+
f"shape mismatch, expected shape of {expected_pytree[key].shape}, but got shape of {new_pytree[key].shape}"
53+
)
54+
else:
55+
max_logging.log(f"key: {key} not found...")
56+
57+
def torch2jax(torch_tensor: torch.Tensor) -> Array:
58+
is_bfloat16 = torch_tensor.dtype == torch.bfloat16
59+
if is_bfloat16:
60+
# upcast the tensor to fp32
61+
torch_tensor = torch_tensor.float()
62+
63+
if torch.device.type != "cpu":
64+
torch_tensor = torch_tensor.to("cpu")
65+
66+
numpy_value = torch_tensor.numpy()
67+
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
68+
return jax_array
3069

3170
def rename_key(key):
3271
regex = r"\w+[.]\d+"
@@ -93,6 +132,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
93132
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
94133
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
95134
return renamed_pt_tuple_key, pt_tensor
135+
136+
# 3d conv layer
137+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
138+
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 5:
139+
pt_tensor = pt_tensor.transpose(2, 3, 4, 1, 0)
140+
return renamed_pt_tuple_key, pt_tensor
96141

97142
# linear layer
98143
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
@@ -103,6 +148,8 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
103148
# old PyTorch layer norm weight
104149
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
105150
if pt_tuple_key[-1] == "gamma":
151+
renamed_pt_tuple_key = pt_tuple_key
152+
pt_tensor = pt_tensor.flatten()
106153
return renamed_pt_tuple_key, pt_tensor
107154

108155
# old PyTorch layer norm bias

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ def __init__(
200200
precision: jax.lax.Precision = None,
201201
attention: str = "dot_product",
202202
):
203-
kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
204-
stride = _canonicalize_tuple(stride, 3, "stride")
205203
self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs)
206204

207205
def __call__(self, x):
@@ -233,19 +231,19 @@ def __init__(
233231
nnx.Conv(
234232
dim,
235233
dim // 2,
236-
kernel_size=(1, 3, 3),
234+
kernel_size=(3, 3),
237235
padding="SAME",
238236
use_bias=True,
239237
rngs=rngs,
240238
),
241239
)
242240
elif mode == "upsample3d":
243241
self.resample = nnx.Sequential(
244-
WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"),
242+
WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
245243
nnx.Conv(
246244
dim,
247245
dim // 2,
248-
kernel_size=(1, 3, 3),
246+
kernel_size=(3, 3),
249247
padding="SAME",
250248
use_bias=True,
251249
rngs=rngs,
@@ -259,11 +257,9 @@ def __init__(
259257
padding=(1, 0, 0),
260258
)
261259
elif mode == "downsample2d":
262-
# TODO - do I need to transpose?
263-
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
260+
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2))
264261
elif mode == "downsample3d":
265-
# TODO - do I need to transpose?
266-
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
262+
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2))
267263
self.time_conv = WanCausalConv3d(
268264
rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
269265
)
@@ -334,7 +330,6 @@ def __init__(
334330
self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False)
335331
self.conv1 = WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1)
336332
self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
337-
self.dropout = nnx.Dropout(dropout, rngs=rngs)
338333
self.conv2 = WanCausalConv3d(rngs=rngs, in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1)
339334
self.conv_shortcut = (
340335
WanCausalConv3d(rngs=rngs, in_channels=in_dim, out_channels=out_dim, kernel_size=1)
@@ -363,7 +358,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
363358

364359
x = self.norm2(x)
365360
x = self.nonlinearity(x)
366-
x = self.dropout(x)
367361

368362
if feat_cache is not None:
369363
idx = feat_idx[0]
@@ -384,8 +378,8 @@ class WanAttentionBlock(nnx.Module):
384378
def __init__(self, dim: int, rngs: nnx.Rngs):
385379
self.dim = dim
386380
self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
387-
self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=1, rngs=rngs)
388-
self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=1, rngs=rngs)
381+
self.to_qkv = nnx.Conv(in_features=dim, out_features=dim * 3, kernel_size=(1, 1), rngs=rngs)
382+
self.proj = nnx.Conv(in_features=dim, out_features=dim, kernel_size=(1, 1), rngs=rngs)
389383

390384
def __call__(self, x: jax.Array):
391385
batch_size, time, height, width, channels = x.shape
@@ -801,8 +795,6 @@ def _encode(self, x: jax.Array):
801795
x = jnp.transpose(x, (0, 2, 3, 4, 1))
802796
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
803797

804-
# self.clear_cache()
805-
806798
t = x.shape[1]
807799
iter_ = 1 + (t - 1) // 4
808800
for i in range(iter_):
@@ -854,8 +846,8 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
854846
def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
855847
if z.shape[-1] != self.z_dim:
856848
# reshape channel last for JAX
857-
x = jnp.transpose(x, (0, 2, 3, 4, 1))
858-
assert x.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {x.shape}"
849+
z = jnp.transpose(z, (0, 2, 3, 4, 1))
850+
assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}"
859851
decoded = self._decode(z).sample
860852
if not return_dict:
861853
return (decoded,)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from maxdiffusion import max_logging
4+
from huggingface_hub import hf_hub_download
5+
from safetensors import safe_open
6+
from flax.traverse_util import flatten_dict, unflatten_dict
7+
from ..modeling_flax_pytorch_utils import (
8+
rename_key,
9+
rename_key_and_reshape_tensor,
10+
torch2jax,
11+
validate_flax_state_dict
12+
)
13+
14+
def _tuple_str_to_int(in_tuple):
15+
out_list = []
16+
for item in in_tuple:
17+
try:
18+
out_list.append(int(item))
19+
except:
20+
out_list.append(item)
21+
return tuple(out_list)
22+
23+
24+
def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
25+
device = jax.devices(device)[0]
26+
with jax.default_device(device):
27+
if hf_download:
28+
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors")
29+
#breakpoint()
30+
max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
31+
32+
if ckpt_path is not None:
33+
tensors = {}
34+
with safe_open(ckpt_path, framework="pt") as f:
35+
for k in f.keys():
36+
tensors[k] = torch2jax(f.get_tensor(k))
37+
flax_state_dict = {}
38+
cpu = jax.local_devices(backend="cpu")[0]
39+
for pt_key, tensor in tensors.items():
40+
renamed_pt_key = rename_key(pt_key)
41+
# Order matters
42+
renamed_pt_key = renamed_pt_key.replace("up_blocks_", "up_blocks.")
43+
renamed_pt_key = renamed_pt_key.replace("mid_block_", "mid_block.")
44+
renamed_pt_key = renamed_pt_key.replace("down_blocks_", "down_blocks.")
45+
46+
renamed_pt_key = renamed_pt_key.replace("conv_in.bias", "conv_in.conv.bias")
47+
renamed_pt_key = renamed_pt_key.replace("conv_in.weight", "conv_in.conv.weight")
48+
renamed_pt_key = renamed_pt_key.replace("conv_out.bias", "conv_out.conv.bias")
49+
renamed_pt_key = renamed_pt_key.replace("conv_out.weight", "conv_out.conv.weight")
50+
renamed_pt_key = renamed_pt_key.replace("attentions_", "attentions.")
51+
renamed_pt_key = renamed_pt_key.replace("resnets_", "resnets.")
52+
renamed_pt_key = renamed_pt_key.replace("upsamplers_", "upsamplers.")
53+
renamed_pt_key = renamed_pt_key.replace("resample_", "resample.")
54+
renamed_pt_key = renamed_pt_key.replace("conv1.bias", "conv1.conv.bias")
55+
renamed_pt_key = renamed_pt_key.replace("conv1.weight", "conv1.conv.weight")
56+
renamed_pt_key = renamed_pt_key.replace("conv2.bias", "conv2.conv.bias")
57+
renamed_pt_key = renamed_pt_key.replace("conv2.weight", "conv2.conv.weight")
58+
renamed_pt_key = renamed_pt_key.replace("time_conv.bias", "time_conv.conv.bias")
59+
renamed_pt_key = renamed_pt_key.replace("time_conv.weight", "time_conv.conv.weight")
60+
renamed_pt_key = renamed_pt_key.replace("quant_conv", "quant_conv.conv")
61+
renamed_pt_key = renamed_pt_key.replace("conv_shortcut", "conv_shortcut.conv")
62+
if "decoder" in renamed_pt_key:
63+
renamed_pt_key = renamed_pt_key.replace("resample.1.bias", "resample.layers.1.bias")
64+
renamed_pt_key = renamed_pt_key.replace("resample.1.weight", "resample.layers.1.weight")
65+
if "encoder" in renamed_pt_key:
66+
renamed_pt_key = renamed_pt_key.replace("resample.1", "resample.conv")
67+
pt_tuple_key = tuple(renamed_pt_key.split("."))
68+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes)
69+
flax_key = _tuple_str_to_int(flax_key)
70+
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
71+
validate_flax_state_dict(eval_shapes, flax_state_dict)
72+
flax_state_dict = unflatten_dict(flax_state_dict)
73+
del tensors
74+
jax.clear_caches()
75+
else:
76+
raise FileNotFoundError(f"Path {ckpt_path} was not found")
77+
78+
return flax_state_dict

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929
WanCausalConv3d,
3030
WanUpsample,
3131
AutoencoderKLWan,
32-
WanEncoder3d,
3332
WanMidBlock,
3433
WanResidualBlock,
3534
WanRMS_norm,
3635
WanResample,
3736
ZeroPaddedConv2D,
3837
WanAttentionBlock,
3938
)
39+
from ..models.wan.wan_utils import load_wan_vae
4040

4141
CACHE_T = 2
4242

@@ -421,6 +421,20 @@ def test_wan_encode(self):
421421
output = wan_vae.encode(input)
422422
assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)
423423

424+
# def test_load_checkpoint(self):
425+
# pretrained_model_name_or_path = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
426+
# key = jax.random.key(0)
427+
# rngs = nnx.Rngs(key)
428+
# wan_vae = AutoencoderKLWan.from_config(
429+
# pretrained_model_name_or_path,
430+
# subfolder="vae",
431+
# rngs=rngs
432+
# )
433+
# graphdef, state = nnx.split(wan_vae)
434+
# params = state.to_pure_dict()
435+
# # This replaces random params with the model.
436+
# params = load_wan_vae(pretrained_model_name_or_path, params, "cpu")
437+
424438

425439
if __name__ == "__main__":
426440
absltest.main()

src/maxdiffusion/utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
is_xformers_available,
8484
requires_backends,
8585
)
86-
from .loading_utils import load_image
86+
from .loading_utils import load_image, load_video
8787
from .logging import get_logger
8888
from .outputs import BaseOutput
8989
from .peft_utils import (
@@ -103,7 +103,6 @@
103103
convert_unet_state_dict_to_peft,
104104
)
105105

106-
107106
logger = get_logger(__name__)
108107

109108

0 commit comments

Comments
 (0)