Skip to content

Commit e2361cf

Browse files
committed
fix
1 parent 07bbd6b commit e2361cf

1 file changed

Lines changed: 16 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,12 @@ def load_vae_weights(
203203
tensors[k] = torch2jax(f.get_tensor(k))
204204
else:
205205
loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
206-
for k, v in loaded_state_dict.items():
206+
for k, v in loaded_state_dict.items():
207207
tensors[k] = torch2jax(v)
208+
209+
print("\nDEBUG: Top 20 keys from VAE Checkpoint (tensors):")
210+
for k in list(tensors.keys())[:20]:
211+
print(k)
208212

209213
flax_state_dict = {}
210214
cpu = jax.local_devices(backend="cpu")[0]
@@ -223,7 +227,7 @@ def load_vae_weights(
223227
pt_list = []
224228
resnet_index = None
225229

226-
for part in pt_tuple_key:
230+
for i, part in enumerate(pt_tuple_key):
227231
# Check for name_N pattern
228232
if "_" in part and part.split("_")[-1].isdigit():
229233
name = "_".join(part.split("_")[:-1])
@@ -237,9 +241,14 @@ def load_vae_weights(
237241
pt_list.append(str(idx))
238242
else:
239243
pt_list.append(part)
240-
elif part in ["conv1", "conv2", "conv_in", "conv_out", "conv_shortcut", "conv"]:
244+
elif part in ["conv1", "conv2", "conv"]:
241245
pt_list.append(part)
242-
pt_list.append("conv")
246+
# Only inject 'conv' if it's not already there
247+
# Check if next part is 'conv'
248+
if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
249+
pass # already has conv
250+
else:
251+
pt_list.append("conv")
243252
else:
244253
pt_list.append(part)
245254

@@ -248,6 +257,9 @@ def load_vae_weights(
248257
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
249258
# _tuple_str_to_int might not be needed if we already injected ints, but it's safe
250259
flax_key = _tuple_str_to_int(flax_key)
260+
261+
if flax_key == ("latents_mean",) or flax_key == ("latents_std",):
262+
continue # Skip stats
251263

252264
if resnet_index is not None:
253265
if flax_key in flax_state_dict:

0 commit comments

Comments
 (0)