Skip to content

Commit 513cf1a

Browse files
committed
fix
1 parent 80e5b52 commit 513cf1a

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,11 @@ def load_transformer_weights(
238238
cpu = jax.local_devices(backend="cpu")[0]
239239
flattened_dict = flatten_dict(eval_shapes)
240240

241+
random_flax_state_dict = {}
242+
for key in flattened_dict:
243+
string_tuple = tuple([str(item) for item in key])
244+
random_flax_state_dict[string_tuple] = flattened_dict[key]
245+
241246
# DEBUG: Print keys to understand mapping
242247
print("DEBUG: Top 20 keys from Checkpoint (tensors):")
243248
for k in list(tensors.keys())[:20]:
@@ -342,14 +347,16 @@ def load_vae_weights(
342347
if name == "resnets":
343348
resnet_index = idx
344349
pt_list.append("resnets")
345-
elif name in ["down_blocks", "up_blocks", "downsamplers", "upsamplers"]:
350+
elif name == "upsamplers":
351+
pt_list.append("upsampler")
352+
# Skip the index 0 for upsampler as Flax uses singular non-list
353+
elif name in ["down_blocks", "up_blocks", "downsamplers"]:
346354
pt_list.append(name)
347355
pt_list.append(str(idx))
348356
else:
349357
pt_list.append(part)
350358
elif part == "upsampler":
351-
pt_list.append("upsamplers")
352-
pt_list.append("0")
359+
pt_list.append("upsampler")
353360
elif part in ["conv1", "conv2", "conv"]:
354361
pt_list.append(part)
355362
# Inject 'conv' if it's not already there AND not just added

0 commit comments

Comments
 (0)