@@ -144,16 +144,6 @@ def replace_suffix(lst, old, new):
     flax_key = tuple(flax_key_str)
     flax_key = _tuple_str_to_int(flax_key)
 
-    if "scale_shift_table" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
-
-    if "audio_caption_projection" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_caption_projection)")
-    if "audio_time_embed" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_time_embed)")
-
-    return flax_key, flax_tensor
-
     if scan_layers and block_index is not None:
         if "transformer_blocks" in flax_key:
             if flax_key in flax_state_dict:
@@ -165,23 +155,6 @@ def replace_suffix(lst, old, new):
             new_tensor = new_tensor.at[block_index].set(flax_tensor)
             flax_tensor = new_tensor
 
-    # DEBUG TRACE
-    if "audio_ff" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18:
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18)")
-    if "to_out" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18 and "attn1" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18 attn1)")
-
-    if "proj_in" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (proj_in)")
-
-    if "scale_shift_table" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
-
-    if "audio_caption_projection" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_caption_projection)")
-    if "audio_time_embed" in str(flax_key):
-        print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_time_embed)")
-
     return flax_key, flax_tensor
 
 def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
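The `scan_layers` branch above stacks each block's tensor into a single array with a leading layer axis via `new_tensor.at[block_index].set(flax_tensor)`; the removed early `return` had been short-circuiting it. A minimal standalone sketch of that stacking pattern (the shapes and layer count are illustrative assumptions, not values from the checkpoint):

```python
import jax.numpy as jnp

num_layers = 28                 # assumed layer count, for illustration
per_block = jnp.ones((64, 64))  # stand-in for one block's kernel

# Allocate the stacked buffer once, then write each block at its index,
# mirroring the new_tensor.at[block_index].set(flax_tensor) pattern above.
stacked = jnp.zeros((num_layers, *per_block.shape), dtype=per_block.dtype)
for block_index in range(num_layers):
    stacked = stacked.at[block_index].set(per_block)

assert stacked.shape == (num_layers, 64, 64)
```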
@@ -248,57 +221,17 @@ def load_transformer_weights(
     for key in flattened_dict:
         string_tuple = tuple([str(item) for item in key])
         random_flax_state_dict[string_tuple] = flattened_dict[key]
-
-    # DEBUG: Print keys to understand mapping
-    print("DEBUG: Top 20 keys from Checkpoint (tensors):")
-    for k in list(tensors.keys())[:20]:
-        print(k)
-
-    print("DEBUG: NON-BLOCK keys in Checkpoint:")
-    for k in tensors.keys():
-        if "transformer_blocks" not in k:
-            print(k)
-
-
-    print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
-    for k in list(random_flax_state_dict.keys())[:20]:
-        print(k)
-
-    print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
-    for k in list(random_flax_state_dict.keys()):
-        k_str = str(k)
-        if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str):
-            print(f"EVAL_SHAPE: {k}")
-        if "proj_out" in k_str or "norm_out" in k_str:
-            print(f"EVAL_SHAPE GLOBAL: {k}")
-
-    # Search for norm in tensors
-    print("\nDEBUG: Search 'norm' in checkpoint keys:")
-    for k in tensors.keys():
-        if "norm" in k and "transformer_blocks" not in k:
-            print(f"CKPT norm: {k}")
 
     for pt_key, tensor in tensors.items():
         renamed_pt_key = rename_key(pt_key)
         renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
 
-        # DEBUG: Check intermediate rename
-        if "audio_ff.net.0.proj" in pt_key:
-            pass
-
         pt_tuple_key = tuple(renamed_pt_key.split("."))
 
         flax_key, flax_tensor = get_key_and_value(
             pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
         )
 
-        # DEBUG: Trace proj_out
-        if "proj_out" in str(flax_key) and "bias" in str(flax_key):
-            print(f"DEBUG: Trace proj_out: {pt_key} -> {flax_key}")
-        # Check if added to dict
-        # It acts global so it should be added below
-
-
         flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
 
     validate_flax_state_dict(eval_shapes, flax_state_dict)
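`get_key_and_value` returns the renamed key/tensor pair, and the loop then stages each converted tensor on the host CPU device so accelerator memory is not consumed before validation and sharding. A reduced sketch of that staging step (the toy key and tensor are assumptions; the real loop also applies `rename_key`/`rename_for_ltx2_transformer` and the per-layer stacking shown earlier):

```python
import jax
import jax.numpy as jnp
import numpy as np

cpu = jax.local_devices(backend="cpu")[0]

tensors = {"proj_out.weight": np.ones((8, 8), dtype=np.float32)}  # toy stand-in
flax_state_dict = {}
for pt_key, tensor in tensors.items():
    flax_key = tuple(pt_key.split("."))
    # Pin the converted array to host memory until it is sharded to devices.
    flax_state_dict[flax_key] = jax.device_put(jnp.asarray(tensor), device=cpu)
```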
@@ -337,10 +270,6 @@ def load_vae_weights(
     loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
     for k, v in loaded_state_dict.items():
         tensors[k] = torch2jax(v)
-
-    print("\nDEBUG: Top 20 keys from VAE Checkpoint (tensors):")
-    for k in list(tensors.keys())[:20]:
-        print(k)
 
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
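`torch2jax` here is the repo's conversion helper. A common minimal implementation looks like the following; this is an assumption about its body, not code taken from this PR:

```python
import jax.numpy as jnp
import torch

def torch2jax(t: torch.Tensor) -> jnp.ndarray:
    # Detach from autograd, move to CPU, and cross into JAX via NumPy.
    return jnp.asarray(t.detach().cpu().numpy())
```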
@@ -375,16 +304,18 @@ def load_vae_weights(
                 pt_list.append(part)
             elif part in ["conv1", "conv2", "conv"]:
                 pt_list.append(part)
-                # Only inject 'conv' if it's not already there
-                # Check if next part is 'conv'
+                # Inject 'conv' only if it's not already there and was not just added.
                 if i + 1 < len(pt_tuple_key) and pt_tuple_key[i + 1] == "conv":
                     pass  # already has conv
-                elif pt_list[-2] == "conv":  # Check previous injection
-                    pass  # already injected conv in previous step (if part was conv1/conv2/conv)
-                # Also avoid injecting if part ITSELF is 'conv'
+                elif pt_list[-1] == "conv":
+                    pass  # already has conv
                 elif part == "conv":
+                    # 'conv' itself was appended above, so no extra injection is
+                    # needed. (With the pt_list[-1] check above, this branch is
+                    # now unreachable; it is kept to mirror the original logic.)
                     pass
                 else:
+                    # part is 'conv1'/'conv2': append 'conv' to match the Flax nesting
                     pt_list.append("conv")
             else:
                 pt_list.append(part)
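The corrected branch always injects a `conv` scope after `conv1`/`conv2` (unless the checkpoint key already carries one) and no longer consults `pt_list[-2]`, which pointed at the element *before* the just-appended `part` and so could wrongly suppress the injection. A standalone sketch of the mapping, reduced to this branch; the example keys are hypothetical:

```python
def inject_conv(pt_tuple_key):
    """Append an extra 'conv' scope after conv1/conv2 to match the Flax nesting."""
    pt_list = []
    for i, part in enumerate(pt_tuple_key):
        if part in ["conv1", "conv2", "conv"]:
            pt_list.append(part)
            if i + 1 < len(pt_tuple_key) and pt_tuple_key[i + 1] == "conv":
                pass  # checkpoint key already carries the extra 'conv'
            elif pt_list[-1] == "conv":
                pass  # part itself is 'conv'; nothing to inject
            else:
                pt_list.append("conv")  # conv1/conv2 gain the 'conv' scope
        else:
            pt_list.append(part)
    return tuple(pt_list)

# Hypothetical keys, for illustration only:
assert inject_conv(("decoder", "res_blocks_0", "conv1", "weight")) == (
    "decoder", "res_blocks_0", "conv1", "conv", "weight")
assert inject_conv(("decoder", "conv_out", "conv", "weight")) == (
    "decoder", "conv_out", "conv", "weight")  # no double injection
```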