Skip to content

Commit d3cf821

Browse files
committed
final cleanup check
1 parent 2041235 commit d3cf821

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def load_vae_weights(
272272

273273
for pt_key, tensor in tensors.items():
274274
renamed_pt_key = rename_key(pt_key)
275+
renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut")
275276

276277
pt_tuple_key = tuple(renamed_pt_key.split("."))
277278

@@ -295,7 +296,7 @@ def load_vae_weights(
295296
pt_list.append(part)
296297
elif part == "upsampler":
297298
pt_list.append("upsampler")
298-
elif part in ["conv1", "conv2", "conv", "conv_in", "conv_out"]:
299+
elif part in ["conv1", "conv2", "conv", "conv_in", "conv_out", "conv_shortcut"]:
299300
pt_list.append(part)
300301
if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
301302
pass
@@ -316,8 +317,6 @@ def load_vae_weights(
316317
flax_key = _tuple_str_to_int(flax_key)
317318

318319
flax_key_str = [str(x) for x in flax_key]
319-
if "conv" in flax_key_str or "bias" in flax_key_str:
320-
pass
321320

322321
if resnet_index is not None:
323322
if flax_key in flax_state_dict:

0 commit comments

Comments
 (0)