|
| 1 | + |
| 2 | +import jax |
| 3 | +import sys |
| 4 | +import os |
| 5 | + |
| 6 | +# Add src to path |
| 7 | +sys.path.append(os.path.join(os.getcwd(), "src")) |
| 8 | + |
| 9 | +from maxdiffusion.models.ltx2.audio_vae import FlaxAutoencoderKLLTX2Audio |
| 10 | +from maxdiffusion.models.ltx2.ltx2_utils import load_audio_vae_weights, rename_for_ltx2_audio_vae |
| 11 | +from maxdiffusion.utils import load_sharded_checkpoint |
| 12 | +from flax import nnx |
| 13 | + |
| 14 | +def debug_keys(): |
| 15 | + print("Initializing Model...") |
| 16 | + config = { |
| 17 | + "base_channels": 128, |
| 18 | + "ch_mult": (1, 2, 4), |
| 19 | + "double_z": True, |
| 20 | + "dropout": 0.0, |
| 21 | + "in_channels": 2, |
| 22 | + "latent_channels": 8, |
| 23 | + "mel_bins": 64, |
| 24 | + "mel_hop_length": 160, |
| 25 | + "mid_block_add_attention": False, |
| 26 | + "norm_type": "pixel", |
| 27 | + "num_res_blocks": 2, |
| 28 | + "output_channels": 2, |
| 29 | + "resolution": 256, |
| 30 | + "sample_rate": 16000, |
| 31 | + "rngs": nnx.Rngs(0) |
| 32 | + } |
| 33 | + |
| 34 | + with jax.default_device(jax.devices("cpu")[0]): |
| 35 | + model = FlaxAutoencoderKLLTX2Audio(**config) |
| 36 | + |
| 37 | + state = nnx.state(model) |
| 38 | + eval_shapes = state.to_pure_dict() |
| 39 | + |
| 40 | + # Print some expected Flax keys |
| 41 | + print("\nSample Flax Keys (Expected):") |
| 42 | + |
| 43 | + def flatten(d, parent_key=()): |
| 44 | + items = [] |
| 45 | + for k, v in d.items(): |
| 46 | + new_key = parent_key + (k,) |
| 47 | + if isinstance(v, dict): |
| 48 | + items.extend(flatten(v, new_key)) |
| 49 | + else: |
| 50 | + items.append(new_key) |
| 51 | + return items |
| 52 | + |
| 53 | + flax_keys = flatten(eval_shapes) |
| 54 | + for k in flax_keys[:20]: |
| 55 | + print(k) |
| 56 | + |
| 57 | + print("\nTotal Flax Keys:", len(flax_keys)) |
| 58 | + |
| 59 | + # Load PyTorch keys |
| 60 | + print("\nLoading PyTorch SafeTensors Keys...") |
| 61 | + pretrained_model_name_or_path = "Lightricks/LTX-2" |
| 62 | + subfolder = "audio_vae" |
| 63 | + |
| 64 | + tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, "cpu") |
| 65 | + pt_keys = list(tensors.keys()) |
| 66 | + |
| 67 | + print("\nSample PyTorch Keys (Original):") |
| 68 | + for k in pt_keys[:20]: |
| 69 | + print(k) |
| 70 | + |
| 71 | + print("\nTesting Renaming Logic...") |
| 72 | + renamed_keys = [] |
| 73 | + for k in pt_keys: |
| 74 | + renamed = rename_for_ltx2_audio_vae(k) |
| 75 | + renamed_keys.append(renamed) |
| 76 | + if "mid_block.resnets.0.conv1.weight" in k: |
| 77 | + print(f"Renaming check: {k} -> {renamed}") |
| 78 | + |
| 79 | + # Check for misaligned expected keys |
| 80 | + # specific missing ones |
| 81 | + targets = [ |
| 82 | + ('decoder', 'mid_block1', 'conv1', 'conv', 'bias'), |
| 83 | + ('decoder', 'mid_block1', 'conv1', 'conv', 'kernel'), |
| 84 | + ] |
| 85 | + |
| 86 | + print("\nSearching for targets in RENAMED keys:") |
| 87 | + for t in targets: |
| 88 | + t_str = ".".join([str(x) for x in t]) |
| 89 | + found = False |
| 90 | + for rk in renamed_keys: |
| 91 | + # We need to simulate the structure mapping logic too? |
| 92 | + # rename_for_ltx2_audio_vae only does string replacement, |
| 93 | + # load_audio_vae_weights does structural mapping (mid_block -> mid_block1) |
| 94 | + pass |
| 95 | + |
| 96 | + # Let's verify specific renaming for mid_block1 |
| 97 | + # PyTorch: decoder.mid_block.resnets.0.conv1.weight |
| 98 | + # My rename: decoder.mid_block.resnets.0.conv1.conv.kernel |
| 99 | + # My logic in load_audio_vae_weights: |
| 100 | + # if "mid_block.resnets.0" in k: replace with mid_block1 |
| 101 | + # -> decoder.mid_block1.conv1.conv.kernel |
| 102 | + # Flax expected: ('decoder', 'mid_block1', 'conv1', 'conv', 'kernel') |
| 103 | + |
| 104 | + # Is it possible that 'mid_block.resnets.0' is NOT in the key? |
| 105 | + # Maybe it's 'mid_block.resnets.0.conv1.weight'? Yes. |
| 106 | + |
| 107 | + # We will print all RENAMED and STRUCTURED keys produced by our logic |
| 108 | + print("\nGenerating final Flax keys from PyTorch keys using current logic...") |
| 109 | + final_keys = set() |
| 110 | + |
| 111 | + for pt_key in pt_keys: |
| 112 | + key = rename_for_ltx2_audio_vae(pt_key) |
| 113 | + |
| 114 | + # Determine conversion to tuple (Same logic as in ltx2_utils.py) |
| 115 | + parts = key.split(".") |
| 116 | + flax_key_parts = [] |
| 117 | + for part in parts: |
| 118 | + if part.isdigit(): |
| 119 | + flax_key_parts.append(int(part)) |
| 120 | + else: |
| 121 | + flax_key_parts.append(part) |
| 122 | + flax_key = tuple(flax_key_parts) |
| 123 | + |
| 124 | + if "mid_block" in pt_key: |
| 125 | + if "mid_block.resnets.0" in pt_key: |
| 126 | + flax_key_str = ".".join([str(x) for x in flax_key]) |
| 127 | + flax_key_str = flax_key_str.replace("mid_block.resnets.0", "mid_block1") |
| 128 | + elif "mid_block.resnets.1" in pt_key: |
| 129 | + flax_key_str = ".".join([str(x) for x in flax_key]) |
| 130 | + flax_key_str = flax_key_str.replace("mid_block.resnets.1", "mid_block2") |
| 131 | + elif "mid_block.attentions.0" in pt_key: |
| 132 | + flax_key_str = ".".join([str(x) for x in flax_key]) |
| 133 | + flax_key_str = flax_key_str.replace("mid_block.attentions.0", "mid_attn") |
| 134 | + else: |
| 135 | + flax_key_str = ".".join([str(x) for x in flax_key]) |
| 136 | + |
| 137 | + parts = flax_key_str.split(".") |
| 138 | + flax_key_parts = [] |
| 139 | + for part in parts: |
| 140 | + if part.isdigit(): |
| 141 | + flax_key_parts.append(int(part)) |
| 142 | + else: |
| 143 | + flax_key_parts.append(part) |
| 144 | + flax_key = tuple(flax_key_parts) |
| 145 | + |
| 146 | + if "down_blocks" in key: |
| 147 | + key_str = ".".join([str(x) for x in flax_key]) |
| 148 | + if "resnets" in key_str: |
| 149 | + key_str = key_str.replace("down_blocks", "down_stages") |
| 150 | + key_str = key_str.replace("resnets", "blocks") |
| 151 | + elif "attentions" in key_str: |
| 152 | + key_str = key_str.replace("down_blocks", "down_stages") |
| 153 | + key_str = key_str.replace("attentions", "attns") |
| 154 | + elif "downsamplers" in key_str: |
| 155 | + key_str = key_str.replace("down_blocks", "down_stages") |
| 156 | + key_str = key_str.replace("downsamplers.0", "downsample") |
| 157 | + |
| 158 | + parts = key_str.split(".") |
| 159 | + flax_key_parts = [] |
| 160 | + for part in parts: |
| 161 | + if part.isdigit(): |
| 162 | + flax_key_parts.append(int(part)) |
| 163 | + else: |
| 164 | + flax_key_parts.append(part) |
| 165 | + flax_key = tuple(flax_key_parts) |
| 166 | + |
| 167 | + if "up_blocks" in key: |
| 168 | + key_str = ".".join([str(x) for x in flax_key]) |
| 169 | + if "resnets" in key_str: |
| 170 | + key_str = key_str.replace("up_blocks", "up_stages") |
| 171 | + key_str = key_str.replace("resnets", "blocks") |
| 172 | + elif "attentions" in key_str: |
| 173 | + key_str = key_str.replace("up_blocks", "up_stages") |
| 174 | + key_str = key_str.replace("attentions", "attns") |
| 175 | + elif "upsamplers" in key_str: |
| 176 | + key_str = key_str.replace("up_blocks", "up_stages") |
| 177 | + key_str = key_str.replace("upsamplers.0", "upsample") |
| 178 | + |
| 179 | + parts = key_str.split(".") |
| 180 | + flax_key_parts = [] |
| 181 | + for part in parts: |
| 182 | + if part.isdigit(): |
| 183 | + flax_key_parts.append(int(part)) |
| 184 | + else: |
| 185 | + flax_key_parts.append(part) |
| 186 | + flax_key = tuple(flax_key_parts) |
| 187 | + |
| 188 | + final_keys.add(flax_key) |
| 189 | + |
| 190 | + print("\nComparing Final Keys vs Expected Keys...") |
| 191 | + flax_keys_set = set(flax_keys) |
| 192 | + missing = flax_keys_set - final_keys |
| 193 | + |
| 194 | + # Filter stats |
| 195 | + filtered_missing = [] |
| 196 | + for k in missing: |
| 197 | + k_str = [str(x) for x in k] |
| 198 | + if "dropout" in k_str or "rngs" in k_str: |
| 199 | + continue |
| 200 | + filtered_missing.append(k) |
| 201 | + |
| 202 | + print(f"Missing Keys (Count: {len(filtered_missing)}):") |
| 203 | + for k in sorted(filtered_missing)[:20]: |
| 204 | + print(k) |
| 205 | + |
| 206 | + print("\nExtra Keys (Count: {len(final_keys - flax_keys_set)}):") |
| 207 | + for k in sorted(list(final_keys - flax_keys_set))[:20]: |
| 208 | + print(k) |
| 209 | + |
| 210 | +if __name__ == "__main__": |
| 211 | + debug_keys() |