We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4c432ea commit 9a66fbeCopy full SHA for 9a66fbe
1 file changed
src/maxdiffusion/models/ltx2/ltx2_utils.py
@@ -259,6 +259,8 @@ def load_vae_weights(
259
for key in flattened_eval:
260
random_flax_state_dict[tuple(str(item) for item in key)] = flattened_eval[key]
261
262
+ needs_vae_prefix = any(key[0] == "vae" for key in random_flax_state_dict)
263
+
264
for pt_key, tensor in tensors.items():
265
# latents_mean and latents_std are nnx.Params and will be loaded correctly.
266
new_key = pt_key
@@ -270,6 +272,8 @@ def load_vae_weights(
270
272
renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut")
271
273
274
pt_tuple_key = tuple(renamed_pt_key.split("."))
275
+ if needs_vae_prefix and pt_tuple_key[0] != "vae":
276
+ pt_tuple_key = ("vae",) + pt_tuple_key
277
278
pt_list = []
279
resnet_index = None
0 commit comments