@@ -54,7 +54,9 @@ def translate_fn(nnx_path_str):
   if hasattr(pipeline, "transformer") and transformer_weight_name:
     max_logging.log(f"Merging LoRA into transformer with rank={rank}")
     h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
-    merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
+    # Filter state dict for transformer keys to avoid confusing warnings
+    transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")}
+    merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
   else:
     max_logging.log("transformer not found or no weight name provided for LoRA.")
 
@@ -64,7 +66,9 @@ def translate_fn(nnx_path_str):
     h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
 
     if h_state_dict is not None:
-      merge_fn(pipeline.connectors, h_state_dict, rank, scale, translate_fn, dtype=dtype)
+      # Filter state dict for connector keys to avoid confusing warnings
+      connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}
+      merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("Could not load LoRA state dict for connectors.")
 
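For context, both hunks apply the same fix: the combined LoRA state dict is narrowed to the key prefix that belongs to the target module before merging, so `merge_fn` never sees entries it cannot map (which is what produced the confusing warnings). Below is a minimal, self-contained sketch of that filtering step; the example keys mirror the `"diffusion_model"` and `"text_embedding_projection"` prefixes used in the diff, but the toy values and the `filter_by_prefix` helper are illustrative, not part of the codebase.

```python
# Toy LoRA state dict; in practice the values would be weight tensors.
lora_state = {
    "diffusion_model.blocks.0.attn.lora_A.weight": "A0",
    "diffusion_model.blocks.0.attn.lora_B.weight": "B0",
    "text_embedding_projection.lora_A.weight": "A1",
}

def filter_by_prefix(state_dict, prefix):
  """Keep only the entries whose key starts with `prefix` (hypothetical helper)."""
  return {k: v for k, v in state_dict.items() if k.startswith(prefix)}

# The transformer merge sees only "diffusion_model.*" keys, and the
# connector merge sees only "text_embedding_projection.*" keys.
transformer_sd = filter_by_prefix(lora_state, "diffusion_model")
connector_sd = filter_by_prefix(lora_state, "text_embedding_projection")

assert set(transformer_sd) == {
    "diffusion_model.blocks.0.attn.lora_A.weight",
    "diffusion_model.blocks.0.attn.lora_B.weight",
}
assert set(connector_sd) == {"text_embedding_projection.lora_A.weight"}
```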