Skip to content

Commit 7c881e5

Browse files
committed
fix
1 parent 513cf1a commit 7c881e5

1 file changed

Lines changed: 12 additions & 21 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -87,24 +87,26 @@ def rename_for_ltx2_transformer(key):
8787
def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=48):
8888
block_index = None
8989

90-
# Handle transformer_blocks_N produced by rename_key
90+
# Handle transformer_blocks_N (underscore) produced by rename_key
9191
if scan_layers and len(pt_tuple_key) > 0 and "transformer_blocks_" in pt_tuple_key[0]:
9292
import re
9393
m = re.match(r"transformer_blocks_(\d+)", pt_tuple_key[0])
9494
if m:
9595
block_index = int(m.group(1))
9696
# Map transformer_blocks_N -> transformer_blocks
9797
pt_tuple_key = ("transformer_blocks",) + pt_tuple_key[1:]
98+
99+
# Handle transformer_blocks.N (dot) from original keys if rename_key didn't underscore it
100+
if scan_layers and len(pt_tuple_key) > 1 and pt_tuple_key[0] == "transformer_blocks" and pt_tuple_key[1].isdigit():
101+
block_index = int(pt_tuple_key[1])
102+
pt_tuple_key = ("transformer_blocks",) + pt_tuple_key[2:]
98103

99104
if scan_layers:
100105
if "transformer_blocks" in pt_tuple_key:
101106
pass # Already handled above or matches standard format
102107

103-
# Handle scale_shift_table keys - they are Params, not Linear layers, so no 'kernel' suffix needed
104-
# We might have renamed them to scale_shift_table already in rename_for_ltx2_transformer
108+
# Handle scale_shift_table keys
105109
if "scale_shift_table" in pt_tuple_key[-1] or "scale_shift_table" in pt_tuple_key:
106-
# if we renamed it to ends with scale_shift_table, use it directly
107-
# But rename_key_and_reshape might have added kernel?
108110
pass
109111

110112
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
@@ -117,28 +119,21 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
117119
temp_key_str = flax_key_str[:-1] + ["scale"]
118120
temp_key = tuple(temp_key_str) # Tuple of strings
119121

120-
# random_flax_state_dict keys are tuples of STRINGS
121122
if temp_key in random_flax_state_dict:
122123
flax_key_str = temp_key_str
123-
# If we are mapping weight -> scale, ensure tensor is 1D?
124-
# Linear weights are 2D (transposed). Scale weights are 1D.
125-
# If input tensor was 1D, rename_key_and_reshape_tensor converts it to 1D?
126-
# No, if it thought it was Linear, it might have transposed (if 2D) or whatever.
127-
# But if it was originally 1D 'weight' (like LayerNorm), rename_key_and_reshape_tensor (Linear logic)
128-
# checks `if pt_tuple_key[-1] == "weight"`.
129-
# Linear logic: `pt_tensor = pt_tensor.T`.
130-
# If 1D, T is same. So harmless for 1D.
131-
pass
132124

133125
# RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
134-
# flax_key_str = [str(k) for k in flax_key] # Already have it
135-
136126
# Fix scale_shift_table mapping if it got 'kernel' appended
137127
if "scale_shift_table" in flax_key_str:
138128
# If the last element is "kernel" or "weight", remove it
139129
if flax_key_str[-1] in ["kernel", "weight"]:
140130
flax_key_str.pop()
141131

132+
# Handle audio_norm_out / norm_out bias mapping
133+
# Does a key renamed to ('audio_norm_out', 'bias') match ('audio_norm_out', 'bias') in random_flax_state_dict?
134+
# Yes, unless rename_key mapped it differently.
135+
# Ensure norm_out/audio_norm_out are preserved.
136+
142137
# Helper to replace last occurrence
143138
def replace_suffix(lst, old, new):
144139
if lst and lst[-1] == old:
@@ -154,10 +149,6 @@ def replace_suffix(lst, old, new):
154149
elif flax_key_str[-1] == "value":
155150
flax_key_str[-1] = "to_v"
156151

157-
# For to_out, rename_key_and_reshape_tensor might map to_out_0 -> proj_attn
158-
# OR if we mapped `to_out_0` -> `to_out` manually, it keeps `to_out` but changes `weight` -> `kernel`
159-
# We just want to ensure consistency.
160-
# If it became `proj_attn`, revert it.
161152
if len(flax_key_str) >= 2 and flax_key_str[-2] == "proj_attn":
162153
# proj_attn, kernel -> to_out, kernel
163154
flax_key_str[-2] = "to_out"

0 commit comments

Comments
 (0)