@@ -221,19 +221,39 @@ def load_vae_weights(
221221
222222 for pt_key , tensor in tensors .items ():
223223 renamed_pt_key = rename_key (pt_key )
224- if ".resnets." in renamed_pt_key :
225- # pattern: resnets.0 -> resnets_0
226- # We need to capture the number after resnets
227- import re
228- # Replace resnets.N with resnets_N
229- renamed_pt_key = re .sub (r"resnets\.(\d+)" , r"resnets_\1" , renamed_pt_key )
230-
224+
231225 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
232226
227+ # Handle resnets.N -> resnets with stacking
228+ resnet_index = None
229+ if "resnets" in pt_tuple_key :
230+ pt_list = list (pt_tuple_key )
231+ # Iterate backwards to safely pop
232+ for i in range (len (pt_list ) - 1 , - 1 , - 1 ):
233+ if pt_list [i ] == "resnets" and i + 1 < len (pt_list ) and pt_list [i + 1 ].isdigit ():
234+ resnet_index = int (pt_list [i + 1 ])
235+ pt_list .pop (i + 1 )
236+ break
237+ pt_tuple_key = tuple (pt_list )
238+
233239 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
234240 flax_key = _tuple_str_to_int (flax_key )
235-
236- flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
241+
242+ if resnet_index is not None :
243+ if flax_key in flax_state_dict :
244+ current_tensor = flax_state_dict [flax_key ]
245+ else :
246+ # Initialize with correct shape from random_flax_state_dict
247+ target_shape = random_flax_state_dict [flax_key ].shape
248+ current_tensor = jnp .zeros (target_shape , dtype = flax_tensor .dtype )
249+
250+ # Place the tensor at the correct index
251+ # flax_tensor is (..., C), target is (N_resnets, ..., C)
252+ # We need to ensure dims match for assignment
253+ current_tensor = current_tensor .at [resnet_index ].set (flax_tensor )
254+ flax_state_dict [flax_key ] = current_tensor
255+ else :
256+ flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
237257
238258 validate_flax_state_dict (eval_shapes , flax_state_dict )
239259 flax_state_dict = unflatten_dict (flax_state_dict )
0 commit comments