@@ -266,16 +266,9 @@ def load_vae_weights(
266266 flattened_eval = flatten_dict (eval_shapes )
267267
268268 random_flax_state_dict = {}
269- print (f"DEBUG: eval_shapes length (flattened): { len (flattened_eval )} " )
270- sample_target_found = False
271269 for key in flattened_eval :
272270 string_tuple = tuple ([str (item ) for item in key ])
273271 random_flax_state_dict [string_tuple ] = flattened_eval [key ]
274-
275- if not sample_target_found and "decoder" in string_tuple and "up_blocks" in string_tuple and "0" in string_tuple and "resnets" in string_tuple and "2" in string_tuple and "conv2" in string_tuple :
276- print (f"DEBUG: Found target key in eval_shapes: { key } " )
277- print (f"DEBUG: Key types: { [type (x ) for x in key ]} " )
278- sample_target_found = True
279272
280273 for pt_key , tensor in tensors .items ():
281274 renamed_pt_key = rename_key (pt_key )
@@ -347,11 +340,21 @@ def load_vae_weights(
347340 else :
348341 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
349342
350- if "decoder" in flax_key_str and "up_blocks" in flax_key_str and "0" in flax_key_str and "resnets" in flax_key_str and "2" in flax_key_str and "conv2" in flax_key_str :
351- print (f"DEBUG: Processing target key. Final flax_key: { flax_key } " )
343+
344+
345+ # Filter out non-parameter keys for validation
346+ filtered_eval_shapes = {}
347+ for k , v in flattened_eval .items ():
348+ # flax key is a tuple of strings/ints
349+ k_str = [str (x ) for x in k ]
350+ if "dropout" in k_str or "rngs" in k_str :
351+ continue
352+ filtered_eval_shapes [k ] = v
352353
353354 print (f"Total VAE keys loaded: { len (flax_state_dict )} " )
354- validate_flax_state_dict (eval_shapes , flax_state_dict )
355+
356+ # Unflatten to pass to validate_flax_state_dict which expects a pytree
357+ validate_flax_state_dict (unflatten_dict (filtered_eval_shapes ), flax_state_dict )
355358 flax_state_dict = unflatten_dict (flax_state_dict )
356359 del tensors
357360 jax .clear_caches ()
0 commit comments