Skip to content

Commit 4c432ea

Browse files
committed
ltx2_utils change for weight loading from a single safetensors file
1 parent c86ac1f commit 4c432ea

1 file changed

Lines changed: 42 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,42 @@
2626
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
2727

2828

29+
# Substring rename table for video-VAE checkpoint keys (LTX 2.0 block layout).
# Each key is a fragment of a source checkpoint key and each value is the
# fragment that replaces it. Presumably this maps the original LTX checkpoint
# naming onto the module names used by this repo's VAE -- TODO confirm.
#
# NOTE: entry order is significant. The loader visible in this change walks
# LTX_2_3_VIDEO_VAE_RENAME_DICT (which begins with these entries) and applies
# each pair with str.replace, one after another, in insertion order. No
# replacement value contains a key that appears later in the table, so a
# fragment that has already been renamed is not rewritten a second time.
# Matching is plain substring matching, so "down_blocks.1" would also match a
# hypothetical "down_blocks.10".
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
30+
# Encoder: even source indices 0-6 become blocks 0-3, each odd index becomes
# "downsamplers.0" of the block before it, and index 8 becomes the mid block.
31+
"down_blocks.0": "down_blocks.0",
32+
"down_blocks.1": "down_blocks.0.downsamplers.0",
33+
"down_blocks.2": "down_blocks.1",
34+
"down_blocks.3": "down_blocks.1.downsamplers.0",
35+
"down_blocks.4": "down_blocks.2",
36+
"down_blocks.5": "down_blocks.2.downsamplers.0",
37+
"down_blocks.6": "down_blocks.3",
38+
"down_blocks.7": "down_blocks.3.downsamplers.0",
39+
"down_blocks.8": "mid_block",
40+
# Decoder: index 0 becomes the mid block; after that each odd index becomes
# "upsamplers.0" of the block that the following even index is renamed to.
41+
"up_blocks.0": "mid_block",
42+
"up_blocks.1": "up_blocks.0.upsamplers.0",
43+
"up_blocks.2": "up_blocks.0",
44+
"up_blocks.3": "up_blocks.1.upsamplers.0",
45+
"up_blocks.4": "up_blocks.1",
46+
"up_blocks.5": "up_blocks.2.upsamplers.0",
47+
"up_blocks.6": "up_blocks.2",
48+
"last_time_embedder": "time_embedder",
49+
"last_scale_shift_table": "scale_shift_table",
50+
# Common: fragments renamed wherever they occur in a key.
51+
# For all 3D ResNets
52+
"res_blocks": "resnets",
# Per-channel statistics are renamed to latents_mean / latents_std.
53+
"per_channel_statistics.mean-of-means": "latents_mean",
54+
"per_channel_statistics.std-of-means": "latents_std",
55+
}
56+
57+
# Rename table for the LTX 2.3 video VAE: every entry of
# LTX_2_0_VIDEO_VAE_RENAME_DICT, followed by two extra decoder entries.
# The loader visible in this change applies this table with sequential
# str.replace calls when the checkpoint filename equals
# "ltx-2.3-22b-dev.safetensors". Because the 2.0 entries are unpacked first,
# the two extra entries below are applied last.
LTX_2_3_VIDEO_VAE_RENAME_DICT = {
58+
**LTX_2_0_VIDEO_VAE_RENAME_DICT,
59+
# Decoder extra blocks: source indices 7 and 8 become "up_blocks.3" and its
# "upsamplers.0", continuing the odd/even pattern of the 2.0 decoder entries.
60+
"up_blocks.7": "up_blocks.3.upsamplers.0",
61+
"up_blocks.8": "up_blocks.3",
62+
}
63+
64+
2965
def _tuple_str_to_int(in_tuple):
3066
out_list = []
3167
for item in in_tuple:
@@ -225,7 +261,12 @@ def load_vae_weights(
225261

226262
for pt_key, tensor in tensors.items():
227263
# latents_mean and latents_std are nnx.Params and will be loaded correctly.
228-
renamed_pt_key = rename_key(pt_key)
264+
new_key = pt_key
265+
if filename == "ltx-2.3-22b-dev.safetensors":
266+
for replace_key, rename_to in LTX_2_3_VIDEO_VAE_RENAME_DICT.items():
267+
new_key = new_key.replace(replace_key, rename_to)
268+
269+
renamed_pt_key = rename_key(new_key)
229270
renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut")
230271

231272
pt_tuple_key = tuple(renamed_pt_key.split("."))

0 commit comments

Comments
 (0)