
Commit 9e4038b

committed: removed extra files and comments
1 parent: 6b2726c

4 files changed: 6 additions & 258 deletions

File tree

debug_eval_shapes.py (deleted): 0 additions & 47 deletions

reproduce_key_mapping.py (deleted): 0 additions & 74 deletions

src/maxdiffusion/models/ltx2/ltx2_utils.py: 6 additions & 108 deletions
@@ -30,54 +30,17 @@ def rename_for_ltx2_transformer(key):
   key = key.replace("patchify_proj", "proj_in")
   key = key.replace("audio_patchify_proj", "audio_proj_in")
   key = key.replace("norm_final", "norm_out")
-
-  # Handle scale_shift_table
-  # PyTorch: adaLN_modulation.1.weight/bias -> scale_shift_table
-  # rename_key changes adaLN_modulation.1 -> adaLN_modulation_1
   if "adaLN_modulation_1" in key:
     key = key.replace("adaLN_modulation_1", "scale_shift_table")
 
   if "caption_modulator_1" in key:
     key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
-
-  # Audio caption modulator?
-  # Checkpoint: audio_caption_modulator.1.weight (Guessing name)
-  # Let's inspect checkpoint keys for clues if this guess fails.
   if "audio_caption_modulator_1" in key:
     key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
-
-  # Handle audio_caption_projection
-  # Checkpoint: audio_caption_projection.linear_1.weight
-  # Flax: audio_caption_projection.linear_1.kernel
-  # rename_key_and_reshape_tensor catches 'weight' -> 'kernel', but maybe something else renaming it?
-  # No explicit rename needed if it's already linear_1/linear_2 unless name mismatch.
-
-  # Handle global norms (norm_out, audio_norm_out)
-  # Checkpoint: norm_final -> norm_out (already handled)
-  # Checkpoint also has audio_norm_final -> audio_norm_out?
   if "audio_norm_final" in key:
     key = key.replace("audio_norm_final", "audio_norm_out")
-
-  # Handle time_embed/audio_time_embed
-  # Checkpoint: time_embed.emb.timestep_embedder.linear_1.weight
-  # Flax: time_embed.emb.timestep_embedder.linear_1.kernel
-  # If checkpoint uses different name structure?
-  # time_embed.emb.timestep_embedder -> time_embed.emb.timestep_embedder (seems OK)
-
-  # Handle av_cross_attn...
-  # These seem fine in name but verify if they are Linear or Conv? Linear.
-
-
-
-  # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
-
-  # Handle audio_ff.net_0.proj -> audio_ff.net_0
-  # Also handle ff.net_0.proj -> ff.net_0
   if ("audio_ff" in key or "ff" in key) and "proj" in key:
     key = key.replace(".proj", "")
-
-  # Handle to_out.0 -> to_out for LTX2Attention
-  # rename_key changes to_out.0 -> to_out_0
   if "to_out_0" in key:
     key = key.replace("to_out_0", "to_out")
 
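For reference, a minimal sketch of what the retained renames do to a few PyTorch-style keys; the sample keys below are illustrative, not taken from a real LTX-2 checkpoint, and only the replacements visible in this hunk are reproduced:

def sketch_rename(key):
  # Subset of the replacements kept in the hunk above.
  key = key.replace("patchify_proj", "proj_in")
  key = key.replace("norm_final", "norm_out")
  if "adaLN_modulation_1" in key:
    key = key.replace("adaLN_modulation_1", "scale_shift_table")
  if ("audio_ff" in key or "ff" in key) and "proj" in key:
    key = key.replace(".proj", "")
  if "to_out_0" in key:
    key = key.replace("to_out_0", "to_out")
  return key

for pt_key in ["patchify_proj.weight",
               "transformer_blocks.0.adaLN_modulation_1.bias",
               "transformer_blocks.0.ff.net_0.proj.weight",
               "transformer_blocks.0.attn1.to_out_0.weight"]:
  print(pt_key, "->", sketch_rename(pt_key))
# patchify_proj.weight -> proj_in.weight
# transformer_blocks.0.adaLN_modulation_1.bias -> transformer_blocks.0.scale_shift_table.bias
# transformer_blocks.0.ff.net_0.proj.weight -> transformer_blocks.0.ff.net_0.weight
# transformer_blocks.0.attn1.to_out_0.weight -> transformer_blocks.0.attn1.to_out.weight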
@@ -110,39 +73,24 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
     pass
 
   flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
-
-  # Check if we got 'kernel' but expected 'scale' (common for scanned layers where shape check fails)
-  # Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
   flax_key_str = [str(k) for k in flax_key]
 
   if flax_key_str[-1] in ["kernel", "weight"]:
-    # Try replacing with scale and check if it exists in random_flax_state_dict
     temp_key_str = flax_key_str[:-1] + ["scale"]
     temp_key = tuple(temp_key_str)  # Tuple of strings
 
     if temp_key in random_flax_state_dict:
       flax_key_str = temp_key_str
     pass
-
-  # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
-  # Fix scale_shift_table mapping if it got 'kernel' appended
   if "scale_shift_table" in flax_key_str:
-    # if last is kernel/weight, remove it
     if flax_key_str[-1] in ["kernel", "weight"]:
       flax_key_str.pop()
 
-  # Handle audio_norm_out / norm_out bias mapping
-  # If renamed to ('audio_norm_out', 'bias') matches ('audio_norm_out', 'bias') in random_flax_state_dict?
-  # Yes. But if rename_key mapped it differently?
-  # Ensure norm_out/audio_norm_out are preserved.
-
-  # Helper to replace last occurrence
   def replace_suffix(lst, old, new):
     if lst and lst[-1] == old:
       lst[-1] = new
     return lst
 
-  # LTX-2 uses to_q, to_k, to_v, to_out, NOT query, key, value, proj_attn
   if "transformer_blocks" in flax_key_str:
     if flax_key_str[-1] == "query":
       flax_key_str[-1] = "to_q"
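The probe above only swaps the trailing suffix when the candidate key actually exists in the randomly initialized Flax params. A standalone sketch of that check; the example state dict is hypothetical:

def probe_scale(flax_key_str, random_params):
  # Try 'scale' in place of a trailing 'kernel'/'weight'; keep it only if
  # the randomly initialized params actually contain that key.
  if flax_key_str[-1] in ["kernel", "weight"]:
    candidate = tuple(flax_key_str[:-1] + ["scale"])
    if candidate in random_params:
      return list(candidate)
  return flax_key_str

random_params = {
    ("transformer_blocks", "0", "norm1", "scale"): None,
    ("transformer_blocks", "0", "attn1", "to_q", "kernel"): None,
}
print(probe_scale(["transformer_blocks", "0", "norm1", "weight"], random_params))
# ['transformer_blocks', '0', 'norm1', 'scale']
print(probe_scale(["transformer_blocks", "0", "attn1", "to_q", "kernel"], random_params))
# ['transformer_blocks', '0', 'attn1', 'to_q', 'kernel']  (no 'scale' entry, unchanged)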
@@ -157,71 +105,41 @@ def replace_suffix(lst, old, new):
 
   flax_key = tuple(flax_key_str)
 
-  # Explicit fixes for LTX-2
-  # 1. Linear Projection Weights: weight -> kernel
-  # Keys: audio_caption_projection, caption_projection, audio_proj_in, proj_in, audio_proj_out, proj_out, time_embed, audio_time_embed
-  # Also av_cross_attn lines.
   if flax_key[-1] == "weight":
-    # Check if we should rename to kernel
-    # Heuristic: if parent has 'linear' or 'proj' or 'time_embed' or matches known list
     parent = flax_key[-2] if len(flax_key) > 1 else ""
     grandparent = flax_key[-3] if len(flax_key) > 2 else ""
 
     should_be_kernel = False
     if "linear" in parent or "proj" in parent or "proj" in grandparent:
       should_be_kernel = True
     if "time_embed" in flax_key[0] or "cross_attn" in flax_key[0]:
-      if "linear" in parent or "emb" in parent: # time_embed.linear
+      if "linear" in parent or "emb" in parent:
        should_be_kernel = True
 
-    # Exception: norm weights are scale, handled below
    if "norm" in parent:
      should_be_kernel = False
 
    if should_be_kernel:
      flax_key = flax_key[:-1] + ("kernel",)
 
-  # 2. Norm Weights: weight -> scale
-  # Keys: norm_k, norm_q, norm_out, audio_norm_out
   if flax_key[-1] == "weight":
     if "norm" in flax_key[-2] or "norm" in flax_key[0]:
       flax_key = flax_key[:-1] + ("scale",)
 
-  # 3. Audio/Video Attention specifics
-  # Checkpoint: attn1.to_q.weight -> Flax: attn1.to_q.kernel
-  # rename_key usually handles this if it sees 'Linear', but here it might miss if valid_prefixes check fails or it's just 'to_q'.
-  # Force q/k/v/out projections to kernel
   if flax_key[-1] == "weight" and flax_key[-2] in ["to_q", "to_k", "to_v", "to_out"]:
     flax_key = flax_key[:-1] + ("kernel",)
 
-  # 4. Fix 'to_out.0' -> 'to_out' if it persisted (sometimes rename_key does to_out_0)
-  # Actually my previous fix in rename_for_ltx2_transformer handles string replacement?
-  # Let's double check tuple.
-  # If we have ('to_out', '0', 'weight') -> ('to_out', 'kernel') ?
-  # Flax expects just 'to_out'? No, Flax nnx.Linear is typically a leaf unless wrapped.
-  # LTX2Attention defines: self.to_out = nnx.Linear(...)
-  # So it expects ('to_out', 'kernel').
-  # Checkpoint has: to_out.0.weight and to_out.0.bias
-  # This implies it was a Sequential or List in PyTorch?
-  # If we see '0' in key, drop it if it's 'to_out'.
   if len(flax_key) >= 2 and flax_key[-2] == "0" and flax_key[-3] == "to_out":
-    # ('to_out', '0', 'kernel') -> ('to_out', 'kernel')
     flax_key = flax_key[:-3] + ("to_out", flax_key[-1])
 
-  # 5. Fix norm_k/q becoming scale
-  # Already handled by rule 2?
-  # Checkpoint: attn1.norm_k.weight
-  # Rule 2: parent 'norm_k' has 'norm' -> becomes scale. Correct.
-
-  flax_key_str = [str(k) for k in flax_key] # Update str list for final check
+  flax_key_str = [str(k) for k in flax_key]
   flax_key = _tuple_str_to_int(flax_key)
 
   if scan_layers and block_index is not None:
     if "transformer_blocks" in flax_key:
       if flax_key in flax_state_dict:
         new_tensor = flax_state_dict[flax_key]
       else:
-        # Initialize with correct shape (layers, ...)
         new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype)
 
       new_tensor = new_tensor.at[block_index].set(flax_tensor)
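The tail of this hunk is the scanned-layer stacking pattern: each block's tensor is written into a leading (num_layers, ...) axis via JAX's functional index update. A runnable sketch with illustrative shapes:

import jax.numpy as jnp

num_layers = 4
block_index = 2
flax_tensor = jnp.ones((128, 64))  # one block's kernel, shape illustrative

# Allocate the stacked buffer once, then place each block at its index.
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype)
new_tensor = new_tensor.at[block_index].set(flax_tensor)
print(new_tensor.shape)  # (4, 128, 64)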
@@ -235,8 +153,6 @@ def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
   """
   index_file = "diffusion_pytorch_model.safetensors.index.json"
   tensors = {}
-
-  # Try to download index file
   try:
     index_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=index_file)
     with open(index_path, "r") as f:
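For context, loading a sharded safetensors checkpoint through such an index file typically looks like the sketch below. load_from_index is a hypothetical helper, not the repo's function; repo_id and subfolder are placeholders, and the single-file fallback that the try/except presumably handles is omitted:

import json
from huggingface_hub import hf_hub_download
from safetensors import safe_open

def load_from_index(repo_id, subfolder):
  index_path = hf_hub_download(repo_id, subfolder=subfolder,
                               filename="diffusion_pytorch_model.safetensors.index.json")
  with open(index_path, "r") as f:
    weight_map = json.load(f)["weight_map"]  # tensor name -> shard filename

  tensors = {}
  for shard in sorted(set(weight_map.values())):
    # Download each unique shard once and read all tensors it contains.
    shard_path = hf_hub_download(repo_id, subfolder=subfolder, filename=shard)
    with safe_open(shard_path, framework="np") as st:
      for name in st.keys():
        tensors[name] = st.get_tensor(name)
  return tensors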
@@ -324,8 +240,6 @@ def load_vae_weights(
     subfolder: str = "vae"
 ):
   device = jax.local_devices(backend=device)[0]
-  # VAE for LTX-2 is likely single file, but safe to use the helper if we wanted general robustness.
-  # But `lightricks/LTX-2` VAE is single file.
 
   filename = "diffusion_pytorch_model.safetensors"
   try:
@@ -365,7 +279,6 @@ def load_vae_weights(
   resnet_index = None
 
   for i, part in enumerate(pt_tuple_key):
-    # Check for name_N pattern
     if "_" in part and part.split("_")[-1].isdigit():
       name = "_".join(part.split("_")[:-1])
       idx = int(part.split("_")[-1])
@@ -375,7 +288,6 @@ def load_vae_weights(
         pt_list.append(str(idx))
       elif name == "upsamplers":
         pt_list.append("upsampler")
-        # Skip the index 0 for upsampler as Flax uses singular non-list
       elif name in ["down_blocks", "up_blocks", "downsamplers"]:
         pt_list.append(name)
         pt_list.append(str(idx))
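The name_N handling in this loop splits PyTorch-style flattened names back into a base name and an integer index. A standalone sketch; the example parts are illustrative:

for part in ["resnets_0", "up_blocks_1", "conv_out"]:
  if "_" in part and part.split("_")[-1].isdigit():
    name = "_".join(part.split("_")[:-1])  # base name, e.g. "up_blocks"
    idx = int(part.split("_")[-1])         # trailing integer index
    print(part, "->", name, idx)
  else:
    print(part, "-> kept as-is")           # "conv_out" has no numeric suffix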
@@ -385,56 +297,42 @@
         pt_list.append("upsampler")
     elif part in ["conv1", "conv2", "conv"]:
       pt_list.append(part)
-      # Inject 'conv' if it's not already there AND not just added
       if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
-        pass # already has conv
+        pass
       elif pt_list[-1] == "conv":
-        pass # already has conv
+        pass
       elif len(pt_list) >= 2 and pt_list[-2] == "conv":
         pass
       elif part == "conv":
         pass
       else:
-        # If part is conv1/conv2, append 'conv'
         pt_list.append("conv")
     else:
       pt_list.append(part)
 
   pt_tuple_key = tuple(pt_list)
 
   flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
-  # _tuple_str_to_int might not be needed if we already injected ints, but it's safe
   flax_key = _tuple_str_to_int(flax_key)
 
-  # Allow latents_mean/std
-
-  # DEBUG
   flax_key_str = [str(x) for x in flax_key]
   if "conv" in flax_key_str or "bias" in flax_key_str:
-    # print(f"DEBUG: VAE Key Map: {pt_tuple_key} -> {flax_key}")
     pass
 
   if resnet_index is not None:
     if flax_key in flax_state_dict:
       current_tensor = flax_state_dict[flax_key]
     else:
-      # Initialize with correct shape from random_flax_state_dict
-      # We must use STRING tuple for lookup in random_flax_state_dict
       str_flax_key = tuple([str(x) for x in flax_key])
 
       if str_flax_key in random_flax_state_dict:
         target_shape = random_flax_state_dict[str_flax_key].shape
         current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
       else:
-        # Fallback if key missing (shouldn't happen with correct mapping)
-        # print(f"Warning: Key {str_flax_key} not found in random_flax_state_dict, cannot stack.")
-        current_tensor = flax_tensor # Might fail shape check later
-
-        # Place the tensor at the correct index
-        # flax_tensor is (..., C), target is (N_resnets, ..., C)
+        current_tensor = flax_tensor
 
     str_flax_key = tuple([str(x) for x in flax_key])
-    if str_flax_key in random_flax_state_dict: # Only stack if we have a valid target
+    if str_flax_key in random_flax_state_dict:
       current_tensor = current_tensor.at[resnet_index].set(flax_tensor)
       flax_state_dict[flax_key] = current_tensor
     else:
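The resnet stacking at the end of this hunk mirrors the scanned-layer pattern, with one twist: random_flax_state_dict is keyed by string tuples, so the integer-bearing flax_key is re-stringified before the target-shape lookup. A sketch; the key path and shapes below are hypothetical, not read from the real VAE:

import jax.numpy as jnp

random_flax_state_dict = {
    # Stacked target: (n_resnets, kh, kw, in, out); values here are hypothetical.
    ("down_blocks", "0", "resnets", "conv1", "conv", "kernel"): jnp.zeros((2, 3, 3, 64, 64)),
}
flax_key = ("down_blocks", 0, "resnets", "conv1", "conv", "kernel")  # ints after _tuple_str_to_int
flax_tensor = jnp.ones((3, 3, 64, 64))
resnet_index = 1

str_flax_key = tuple(str(x) for x in flax_key)  # string tuple for the lookup
if str_flax_key in random_flax_state_dict:
  target_shape = random_flax_state_dict[str_flax_key].shape
  current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
  current_tensor = current_tensor.at[resnet_index].set(flax_tensor)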
