Skip to content

Commit 359daa8

Browse files
committed
fix
1 parent 3808925 commit 359daa8

1 file changed

Lines changed: 24 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,16 @@ def replace_suffix(lst, old, new):
144144
flax_key = tuple(flax_key_str)
145145
flax_key = _tuple_str_to_int(flax_key)
146146

147+
if "scale_shift_table" in str(flax_key):
148+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
149+
150+
if "audio_caption_projection" in str(flax_key):
151+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_caption_projection)")
152+
if "audio_time_embed" in str(flax_key):
153+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_time_embed)")
154+
155+
return flax_key, flax_tensor
156+
147157
if scan_layers and block_index is not None:
148158
if "transformer_blocks" in flax_key:
149159
if flax_key in flax_state_dict:
@@ -167,6 +177,11 @@ def replace_suffix(lst, old, new):
167177
if "scale_shift_table" in str(flax_key):
168178
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
169179

180+
if "audio_caption_projection" in str(flax_key):
181+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_caption_projection)")
182+
if "audio_time_embed" in str(flax_key):
183+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_time_embed)")
184+
170185
return flax_key, flax_tensor
171186

172187
def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
@@ -388,17 +403,22 @@ def load_vae_weights(
388403
current_tensor = flax_state_dict[flax_key]
389404
else:
390405
# Initialize with correct shape from random_flax_state_dict
391-
if flax_key in random_flax_state_dict:
392-
target_shape = random_flax_state_dict[flax_key].shape
406+
# Look up random_flax_state_dict with an all-string tuple: convert every component of flax_key to str first.
407+
str_flax_key = tuple([str(x) for x in flax_key])
408+
409+
if str_flax_key in random_flax_state_dict:
410+
target_shape = random_flax_state_dict[str_flax_key].shape
393411
current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
394412
else:
395413
# Fallback if key missing (shouldn't happen with correct mapping)
396-
print(f"Warning: Key {flax_key} not found in random_flax_state_dict, cannot stack.")
414+
print(f"Warning: Key {str_flax_key} not found in random_flax_state_dict, cannot stack.")
397415
current_tensor = flax_tensor # Might fail shape check later
398416

399417
# Place the tensor at the correct index
400418
# flax_tensor is (..., C), target is (N_resnets, ..., C)
401-
if flax_key in random_flax_state_dict: # Only stack if we have a valid target
419+
420+
str_flax_key = tuple([str(x) for x in flax_key])
421+
if str_flax_key in random_flax_state_dict: # Only stack if we have a valid target
402422
current_tensor = current_tensor.at[resnet_index].set(flax_tensor)
403423
flax_state_dict[flax_key] = current_tensor
404424
else:

0 commit comments

Comments
 (0)