Skip to content

Commit 9a66fbe

Browse files
committed
ltx2_utils change for weight loading from a single safetensors file
1 parent 4c432ea commit 9a66fbe

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ def load_vae_weights(
259259
for key in flattened_eval:
260260
random_flax_state_dict[tuple(str(item) for item in key)] = flattened_eval[key]
261261

262+
needs_vae_prefix = any(key[0] == "vae" for key in random_flax_state_dict)
263+
262264
for pt_key, tensor in tensors.items():
263265
# latents_mean and latents_std are nnx.Params and will be loaded correctly.
264266
new_key = pt_key
@@ -270,6 +272,8 @@ def load_vae_weights(
270272
renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut")
271273

272274
pt_tuple_key = tuple(renamed_pt_key.split("."))
275+
if needs_vae_prefix and pt_tuple_key[0] != "vae":
276+
pt_tuple_key = ("vae",) + pt_tuple_key
273277

274278
pt_list = []
275279
resnet_index = None

0 commit comments

Comments
 (0)