@@ -203,8 +203,12 @@ def load_vae_weights(
203203 tensors [k ] = torch2jax (f .get_tensor (k ))
204204 else :
205205 loaded_state_dict = torch .load (ckpt_path , map_location = "cpu" )
206- for k , v in loaded_state_dict .items ():
206+ for k , v in loaded_state_dict .items ():
207207 tensors [k ] = torch2jax (v )
208+
209+ print ("\n DEBUG: Top 20 keys from VAE Checkpoint (tensors):" )
210+ for k in list (tensors .keys ())[:20 ]:
211+ print (k )
208212
209213 flax_state_dict = {}
210214 cpu = jax .local_devices (backend = "cpu" )[0 ]
@@ -223,7 +227,7 @@ def load_vae_weights(
223227 pt_list = []
224228 resnet_index = None
225229
226- for part in pt_tuple_key :
230+ for i , part in enumerate ( pt_tuple_key ) :
227231 # Check for name_N pattern
228232 if "_" in part and part .split ("_" )[- 1 ].isdigit ():
229233 name = "_" .join (part .split ("_" )[:- 1 ])
@@ -237,9 +241,14 @@ def load_vae_weights(
237241 pt_list .append (str (idx ))
238242 else :
239243 pt_list .append (part )
240- elif part in ["conv1" , "conv2" , "conv_in" , "conv_out" , "conv_shortcut" , " conv" ]:
244+ elif part in ["conv1" , "conv2" , "conv" ]:
241245 pt_list .append (part )
242- pt_list .append ("conv" )
246+ # Only inject 'conv' if it's not already there
247+ # Check if next part is 'conv'
248+ if i + 1 < len (pt_tuple_key ) and pt_tuple_key [i + 1 ] == "conv" :
249+ pass # already has conv
250+ else :
251+ pt_list .append ("conv" )
243252 else :
244253 pt_list .append (part )
245254
@@ -248,6 +257,9 @@ def load_vae_weights(
248257 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
249258 # _tuple_str_to_int might not be needed if we already injected ints, but it's safe
250259 flax_key = _tuple_str_to_int (flax_key )
260+
261+ if flax_key == ("latents_mean" ,) or flax_key == ("latents_std" ,):
262+ continue # Skip stats
251263
252264 if resnet_index is not None :
253265 if flax_key in flax_state_dict :
0 commit comments