@@ -96,6 +96,12 @@ def rename_for_ltx2_3_vocoder(key):
9696 "model.diffusion_model." : "" ,
9797 "connectors." : "" ,
9898 "transformer_1d_blocks" : "stacked_blocks" ,
99+ "video_embeddings_connector" : "video_connector" ,
100+ "audio_embeddings_connector" : "audio_connector" ,
101+ "ff.net.0.proj.weight" : "ff.net_0.kernel" ,
102+ "ff.net.0.proj.bias" : "ff.net_0.bias" ,
103+ "ff.net.2.weight" : "ff.net_2.kernel" ,
104+ "ff.net.2.bias" : "ff.net_2.bias" ,
99105 "text_embedding_projection.audio_aggregate_embed.weight" : "audio_text_proj_in.kernel" ,
100106 "text_embedding_projection.audio_aggregate_embed.bias" : "audio_text_proj_in.bias" ,
101107 "text_embedding_projection.video_aggregate_embed.weight" : "video_text_proj_in.kernel" ,
@@ -392,6 +398,8 @@ def load_vocoder_weights_2_3(
392398
393399 flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
394400
401+ print (f"DEBUG Vocoder eval_shapes keys: { list (flatten_dict (eval_shapes ).keys ())[:20 ]} " )
402+ print (f"DEBUG Vocoder flax_state_dict keys: { list (flax_state_dict .keys ())[:20 ]} " )
395403 validate_flax_state_dict (eval_shapes , flax_state_dict )
396404 return unflatten_dict (flax_state_dict )
397405
0 commit comments