Skip to content

Commit 0b86f50

Browse files
committed
weight loading
1 parent 758cb92 commit 0b86f50

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ def load_vae_weights(
274274
needs_vae_prefix = any(key[0] == "vae" for key in random_flax_state_dict)
275275

276276
for pt_key, tensor in tensors.items():
277+
# Filter keys for combined checkpoint to avoid noise and memory overhead
278+
if filename == "ltx-2.3-22b-dev.safetensors":
279+
if not (pt_key.startswith("vae.") or pt_key.startswith("audio_vae.")):
280+
continue
281+
277282
# latents_mean and latents_std are nnx.Params and will be loaded correctly.
278283
new_key = pt_key
279284
if filename == "ltx-2.3-22b-dev.safetensors":
@@ -295,7 +300,7 @@ def load_vae_weights(
295300
name = "_".join(part.split("_")[:-1])
296301
idx = int(part.split("_")[-1])
297302

298-
if name == "resnets":
303+
if name == "resnets" or name == "block":
299304
pt_list.append("resnets")
300305
resnet_index = idx
301306
elif name == "upsamplers":
@@ -322,7 +327,7 @@ def load_vae_weights(
322327

323328
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
324329
flax_key = _tuple_str_to_int(flax_key)
325-
max_logging.log(f"Mapped VAE key: {pt_key} -> {flax_key}")
330+
max_logging.log(f"Mapped key: {pt_key} -> {flax_key}")
326331

327332
if resnet_index is not None:
328333
str_flax_key = tuple([str(x) for x in flax_key])

0 commit comments

Comments
 (0)