File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -340,6 +340,9 @@ def load_vae_weights(
340340 else :
341341 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
342342
343+ # print(f"Loaded VAE Key: {flax_key}")
344+
345+ print (f"Total VAE keys loaded: { len (flax_state_dict )} " )
343346 validate_flax_state_dict (eval_shapes , flax_state_dict )
344347 flax_state_dict = unflatten_dict (flax_state_dict )
345348 del tensors
Original file line number Diff line number Diff line change @@ -108,7 +108,8 @@ def test_load_vae_weights(self):
108108 print ("Validating VAE Weights..." )
109109 # Filter out dropout/rngs keys from eval_shapes as they are not expected in weights
110110 filtered_eval_shapes = {}
111- for k , v in eval_shapes .items ():
111+ flat_eval_shapes = flatten_dict (eval_shapes )
112+ for k , v in flat_eval_shapes .items ():
112113 k_str = [str (x ) for x in k ]
113114 if "dropout" in k_str or "rngs" in k_str :
114115 continue
You can’t perform that action at this time.
0 commit comments