@@ -224,34 +224,52 @@ def load_vae_weights(
224224
225225 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
226226
227- # Handle resnets.N -> resnets with stacking
227+ pt_list = []
228228 resnet_index = None
229- if "resnets" in pt_tuple_key :
230- pt_list = list (pt_tuple_key )
231- # Iterate backwards to safely pop
232- for i in range (len (pt_list ) - 1 , - 1 , - 1 ):
233- if pt_list [i ] == "resnets" and i + 1 < len (pt_list ) and pt_list [i + 1 ].isdigit ():
234- resnet_index = int (pt_list [i + 1 ])
235- pt_list .pop (i + 1 )
236- break
237- pt_tuple_key = tuple (pt_list )
229+
230+ for part in pt_tuple_key :
231+ # Check for name_N pattern
232+ if "_" in part and part .split ("_" )[- 1 ].isdigit ():
233+ name = "_" .join (part .split ("_" )[:- 1 ])
234+ idx = int (part .split ("_" )[- 1 ])
235+
236+ if name == "resnets" :
237+ resnet_index = idx
238+ pt_list .append ("resnets" )
239+ elif name in ["down_blocks" , "up_blocks" , "downsamplers" , "upsamplers" ]:
240+ pt_list .append (name )
241+ pt_list .append (idx )
242+ else :
243+ pt_list .append (part )
244+ else :
245+ pt_list .append (part )
246+
247+ pt_tuple_key = tuple (pt_list )
238248
239249 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
250+ # _tuple_str_to_int might not be needed if we already injected ints, but it's safe
240251 flax_key = _tuple_str_to_int (flax_key )
241252
242253 if resnet_index is not None :
243254 if flax_key in flax_state_dict :
244255 current_tensor = flax_state_dict [flax_key ]
245256 else :
246257 # Initialize with correct shape from random_flax_state_dict
247- target_shape = random_flax_state_dict [flax_key ].shape
248- current_tensor = jnp .zeros (target_shape , dtype = flax_tensor .dtype )
258+ if flax_key in random_flax_state_dict :
259+ target_shape = random_flax_state_dict [flax_key ].shape
260+ current_tensor = jnp .zeros (target_shape , dtype = flax_tensor .dtype )
261+ else :
262+ # Fallback if key missing (shouldn't happen with correct mapping)
263+ print (f"Warning: Key { flax_key } not found in random_flax_state_dict, cannot stack." )
264+ current_tensor = flax_tensor # Might fail shape check later
249265
250266 # Place the tensor at the correct index
251267 # flax_tensor is (..., C), target is (N_resnets, ..., C)
252- # We need to ensure dims match for assignment
253- current_tensor = current_tensor .at [resnet_index ].set (flax_tensor )
254- flax_state_dict [flax_key ] = current_tensor
268+ if flax_key in random_flax_state_dict : # Only stack if we have a valid target
269+ current_tensor = current_tensor .at [resnet_index ].set (flax_tensor )
270+ flax_state_dict [flax_key ] = current_tensor
271+ else :
272+ flax_state_dict [flax_key ] = flax_tensor
255273 else :
256274 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
257275
0 commit comments