@@ -87,24 +87,26 @@ def rename_for_ltx2_transformer(key):
8787def get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers , num_layers = 48 ):
8888 block_index = None
8989
90- # Handle transformer_blocks_N produced by rename_key
90+ # Handle transformer_blocks_N (underscore) produced by rename_key
9191 if scan_layers and len (pt_tuple_key ) > 0 and "transformer_blocks_" in pt_tuple_key [0 ]:
9292 import re
9393 m = re .match (r"transformer_blocks_(\d+)" , pt_tuple_key [0 ])
9494 if m :
9595 block_index = int (m .group (1 ))
9696 # Map transformer_blocks_N -> transformer_blocks
9797 pt_tuple_key = ("transformer_blocks" ,) + pt_tuple_key [1 :]
98+
99+ # Handle transformer_blocks.N (dot) from original keys if rename_key didn't underscore it
100+ if scan_layers and len (pt_tuple_key ) > 1 and pt_tuple_key [0 ] == "transformer_blocks" and pt_tuple_key [1 ].isdigit ():
101+ block_index = int (pt_tuple_key [1 ])
102+ pt_tuple_key = ("transformer_blocks" ,) + pt_tuple_key [2 :]
98103
99104 if scan_layers :
100105 if "transformer_blocks" in pt_tuple_key :
101106 pass # Already handled above or matches standard format
102107
103- # Handle scale_shift_table keys - they are Params, not Linear layers, so no 'kernel' suffix needed
104- # We might have renamed them to scale_shift_table already in rename_for_ltx2_transformer
108+ # Handle scale_shift_table keys
105109 if "scale_shift_table" in pt_tuple_key [- 1 ] or "scale_shift_table" in pt_tuple_key :
106- # if we renamed it to ends with scale_shift_table, use it directly
107- # But rename_key_and_reshape might have added kernel?
108110 pass
109111
110112 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
@@ -117,28 +119,21 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
117119 temp_key_str = flax_key_str [:- 1 ] + ["scale" ]
118120 temp_key = tuple (temp_key_str ) # Tuple of strings
119121
120- # random_flax_state_dict keys are tuples of STRINGS
121122 if temp_key in random_flax_state_dict :
122123 flax_key_str = temp_key_str
123- # If we are mapping weight -> scale, ensure tensor is 1D?
124- # Linear weights are 2D (transposed). Scale weights are 1D.
125- # If input tensor was 1D, rename_key_and_reshape_tensor converts it to 1D?
126- # No, if it thought it was Linear, it might have transposed (if 2D) or whatever.
127- # But if it was originally 1D 'weight' (like LayerNorm), rename_key_and_reshape_tensor (Linear logic)
128- # checks `if pt_tuple_key[-1] == "weight"`.
129- # Linear logic: `pt_tensor = pt_tensor.T`.
130- # If 1D, T is same. So harmless for 1D.
131- pass
132124
133125 # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
134- # flax_key_str = [str(k) for k in flax_key] # Already have it
135-
136126 # Fix scale_shift_table mapping if it got 'kernel' appended
137127 if "scale_shift_table" in flax_key_str :
138128 # if last is kernel/weight, remove it
139129 if flax_key_str [- 1 ] in ["kernel" , "weight" ]:
140130 flax_key_str .pop ()
141131
132+ # Handle audio_norm_out / norm_out bias mapping
133+ # If renamed to ('audio_norm_out', 'bias') matches ('audio_norm_out', 'bias') in random_flax_state_dict?
134+ # Yes. But if rename_key mapped it differently?
135+ # Ensure norm_out/audio_norm_out are preserved.
136+
142137 # Helper to replace last occurrence
143138 def replace_suffix (lst , old , new ):
144139 if lst and lst [- 1 ] == old :
@@ -154,10 +149,6 @@ def replace_suffix(lst, old, new):
154149 elif flax_key_str [- 1 ] == "value" :
155150 flax_key_str [- 1 ] = "to_v"
156151
157- # For to_out, rename_key_and_reshape_tensor might map to_out_0 -> proj_attn
158- # OR if we mapped `to_out_0` -> `to_out` manually, it keeps `to_out` but changes `weight` -> `kernel`
159- # We just want to ensure consistency.
160- # If it became `proj_attn`, revert it.
161152 if len (flax_key_str ) >= 2 and flax_key_str [- 2 ] == "proj_attn" :
162153 # proj_attn, kernel -> to_out, kernel
163154 flax_key_str [- 2 ] = "to_out"
0 commit comments