@@ -60,13 +60,13 @@ def rename_for_ltx2_3_transformer(key):
6060 # Handle substrings before they are replaced by shorter patterns below
6161 key = key .replace ("audio_prompt_adaln_single" , "audio_prompt_adaln" )
6262 key = key .replace ("prompt_adaln_single" , "prompt_adaln" )
63- key = key .replace ("audio_prompt_scale_shift_table" , "audio_scale_shift_table" )
64- key = key .replace ("prompt_scale_shift_table" , "scale_shift_table" )
63+ # key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
64+ # key = key.replace("prompt_scale_shift_table", "scale_shift_table")
6565
66- if "prompt_adaln" in key :
67- key = key .replace ("prompt_adaln" , "caption_projection" )
68- if "audio_prompt_adaln" in key :
69- key = key .replace ("audio_prompt_adaln" , "audio_caption_projection" )
66+ # if "prompt_adaln" in key:
67+ # key = key.replace("prompt_adaln", "caption_projection")
68+ # if "audio_prompt_adaln" in key:
69+ # key = key.replace("audio_prompt_adaln", "audio_caption_projection")
7070 if "video_text_proj_in" in key :
7171 key = key .replace ("video_text_proj_in" , "feature_extractor.video_linear" )
7272 if "audio_text_proj_in" in key :
@@ -398,8 +398,6 @@ def load_vocoder_weights_2_3(
398398
399399 flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
400400
401- print (f"DEBUG Vocoder eval_shapes keys: { list (flatten_dict (eval_shapes ).keys ())[:20 ]} " )
402- print (f"DEBUG Vocoder flax_state_dict keys: { list (flax_state_dict .keys ())[:20 ]} " )
403401 validate_flax_state_dict (eval_shapes , flax_state_dict )
404402 return unflatten_dict (flax_state_dict )
405403
@@ -481,8 +479,6 @@ def load_connectors_weights_2_3(
481479 stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
482480 flax_state_dict [base_key ] = jax .device_put (stacked_tensor , device = cpu )
483481
484- print (f"DEBUG Connectors eval_shapes keys: { list (flattened_eval .keys ())[:20 ]} " )
485- print (f"DEBUG Connectors flax_state_dict keys: { list (flax_state_dict .keys ())[:20 ]} " )
486482 filtered_eval_shapes = {
487483 k : v for k , v in flattened_eval .items () if not any ("dropout" in str (x ) or "rngs" in str (x ) for x in k )
488484 }
0 commit comments