@@ -144,6 +144,16 @@ def replace_suffix(lst, old, new):
144144 flax_key = tuple (flax_key_str )
145145 flax_key = _tuple_str_to_int (flax_key )
146146
147+ if "scale_shift_table" in str (flax_key ):
148+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (scale_shift_table)" )
149+
150+ if "audio_caption_projection" in str (flax_key ):
151+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (audio_caption_projection)" )
152+ if "audio_time_embed" in str (flax_key ):
153+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (audio_time_embed)" )
154+
155+ return flax_key , flax_tensor
156+
147157 if scan_layers and block_index is not None :
148158 if "transformer_blocks" in flax_key :
149159 if flax_key in flax_state_dict :
@@ -167,6 +177,11 @@ def replace_suffix(lst, old, new):
167177 if "scale_shift_table" in str (flax_key ):
168178 print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (scale_shift_table)" )
169179
180+ if "audio_caption_projection" in str (flax_key ):
181+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (audio_caption_projection)" )
182+ if "audio_time_embed" in str (flax_key ):
183+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (audio_time_embed)" )
184+
170185 return flax_key , flax_tensor
171186
172187def load_sharded_checkpoint (pretrained_model_name_or_path , subfolder , device ):
@@ -388,17 +403,22 @@ def load_vae_weights(
388403 current_tensor = flax_state_dict [flax_key ]
389404 else :
390405 # Initialize with correct shape from random_flax_state_dict
391- if flax_key in random_flax_state_dict :
392- target_shape = random_flax_state_dict [flax_key ].shape
406+ # We must use STRING tuple for lookup in random_flax_state_dict
407+ str_flax_key = tuple ([str (x ) for x in flax_key ])
408+
409+ if str_flax_key in random_flax_state_dict :
410+ target_shape = random_flax_state_dict [str_flax_key ].shape
393411 current_tensor = jnp .zeros (target_shape , dtype = flax_tensor .dtype )
394412 else :
395413 # Fallback if key missing (shouldn't happen with correct mapping)
396- print (f"Warning: Key { flax_key } not found in random_flax_state_dict, cannot stack." )
414+ print (f"Warning: Key { str_flax_key } not found in random_flax_state_dict, cannot stack." )
397415 current_tensor = flax_tensor # Might fail shape check later
398416
399417 # Place the tensor at the correct index
400418 # flax_tensor is (..., C), target is (N_resnets, ..., C)
401- if flax_key in random_flax_state_dict : # Only stack if we have a valid target
419+
420+ str_flax_key = tuple ([str (x ) for x in flax_key ])
421+ if str_flax_key in random_flax_state_dict : # Only stack if we have a valid target
402422 current_tensor = current_tensor .at [resnet_index ].set (flax_tensor )
403423 flax_state_dict [flax_key ] = current_tensor
404424 else :
0 commit comments