File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -67,6 +67,28 @@ def flatten(d, parent_key=()):
6767 for k in pt_keys [:20 ]:
6868 print (k )
6969
70+ print ("\n Sample Encoder Keys:" )
71+ enc_keys = [k for k in pt_keys if "encoder" in k ]
72+ for k in enc_keys [:20 ]:
73+ print (k )
74+
75+ # Check specific encoder key shape if possible?
76+ # Can't easily check shape here without loading tensor, but load_sharded_checkpoint loads all.
77+ # tensors is already loaded.
78+
79+ print ("\n Checking Encoder Down Block 0 shape:" )
80+ if "encoder.down.0.block.0.conv1.conv.weight" in tensors :
81+ print ("encoder.down.0.block.0.conv1.conv.weight:" , tensors ["encoder.down.0.block.0.conv1.conv.weight" ].shape )
82+ if "encoder.down.0.block.1.conv1.conv.weight" in tensors :
83+ print ("encoder.down.0.block.1.conv1.conv.weight:" , tensors ["encoder.down.0.block.1.conv1.conv.weight" ].shape )
84+ if "encoder.down.1.block.0.conv1.conv.weight" in tensors :
85+ print ("encoder.down.1.block.0.conv1.conv.weight:" , tensors ["encoder.down.1.block.0.conv1.conv.weight" ].shape )
86+
87+ print ("\n Checking Decoder Up Block 0 shape:" )
88+ if "decoder.up.0.block.0.conv1.conv.weight" in tensors :
89+ print ("decoder.up.0.block.0.conv1.conv.weight:" , tensors ["decoder.up.0.block.0.conv1.conv.weight" ].shape )
90+
91+
7092 print ("\n Testing Renaming Logic..." )
7193 renamed_keys = []
7294 for k in pt_keys :
You can’t perform that action at this time.
0 commit comments