@@ -238,6 +238,11 @@ def load_transformer_weights(
238238 cpu = jax .local_devices (backend = "cpu" )[0 ]
239239 flattened_dict = flatten_dict (eval_shapes )
240240
241+ random_flax_state_dict = {}
242+ for key in flattened_dict :
243+ string_tuple = tuple ([str (item ) for item in key ])
244+ random_flax_state_dict [string_tuple ] = flattened_dict [key ]
245+
241246 # DEBUG: Print keys to understand mapping
242247 print ("DEBUG: Top 20 keys from Checkpoint (tensors):" )
243248 for k in list (tensors .keys ())[:20 ]:
@@ -342,14 +347,16 @@ def load_vae_weights(
342347 if name == "resnets" :
343348 resnet_index = idx
344349 pt_list .append ("resnets" )
345- elif name in ["down_blocks" , "up_blocks" , "downsamplers" , "upsamplers" ]:
350+ elif name == "upsamplers" :
351+ pt_list .append ("upsampler" )
352+ # Skip the index 0 for upsampler as Flax uses singular non-list
353+ elif name in ["down_blocks" , "up_blocks" , "downsamplers" ]:
346354 pt_list .append (name )
347355 pt_list .append (str (idx ))
348356 else :
349357 pt_list .append (part )
350358 elif part == "upsampler" :
351- pt_list .append ("upsamplers" )
352- pt_list .append ("0" )
359+ pt_list .append ("upsampler" )
353360 elif part in ["conv1" , "conv2" , "conv" ]:
354361 pt_list .append (part )
355362 # Inject 'conv' if it's not already there AND not just added
0 commit comments