@@ -52,79 +52,22 @@ def load_lora_weights(
 
     # Handle high noise model
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
-      max_logging.log(f"Injecting LoRA into high_noise_transformer with rank={rank}")
-      lora_nnx.inject_lora(
-          pipeline.high_noise_transformer, rank=rank, scale=scale, rngs=nnx.Rngs(rng), target_linear=True, target_conv=True
-      )
-      h_state_dict, h_alphas = lora_loader.lora_state_dict(
+      max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
+      h_state_dict, _ = lora_loader.lora_state_dict(
           lora_model_path, weight_name=high_noise_weight_name, **kwargs
       )
-      self._assign_weights_to_nnx_model(pipeline.high_noise_transformer, h_state_dict, h_alphas if h_alphas else {})
+      lora_nnx.merge_lora(pipeline.high_noise_transformer, h_state_dict, scale)
     else:
       max_logging.warning("high_noise_transformer not found or no weight name provided for LoRA.")
 
     # Handle low noise model
     if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
-      max_logging.log(f"Injecting LoRA into low_noise_transformer with rank={rank}")
-      lora_nnx.inject_lora(
-          pipeline.low_noise_transformer, rank=rank, scale=scale, rngs=nnx.Rngs(rng), target_linear=True, target_conv=True
-      )
-      l_state_dict, l_alphas = lora_loader.lora_state_dict(
+      max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
+      l_state_dict, _ = lora_loader.lora_state_dict(
           lora_model_path, weight_name=low_noise_weight_name, **kwargs
       )
-      self._assign_weights_to_nnx_model(pipeline.low_noise_transformer, l_state_dict, l_alphas if l_alphas else {})
+      lora_nnx.merge_lora(pipeline.low_noise_transformer, l_state_dict, scale)
     else:
       max_logging.warning("low_noise_transformer not found or no weight name provided for LoRA.")
 
     return pipeline
-
-  def _assign_weights_to_nnx_model(self, model: nnx.Module, state_dict: dict, network_alphas: dict):
-    """
-    Assigns weights from a Diffusers-formatted state dict to
-    injected LoRALinear/LoRAConv layers in an NNX model.
-    """
-    lora_params = {}
-    for k, v in state_dict.items():
-      m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
-      if not m:
-        m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
-
-      if m:
-        module_path_str, weight_type = m.group(1), m.group(2)
-        if module_path_str not in lora_params:
-          lora_params[module_path_str] = {}
-        lora_params[module_path_str][weight_type] = jnp.array(v)
-      else:
-        max_logging.warning(f"Could not parse LoRA key: {k}")
-
-    assigned_count = 0
-    for path, submodule in nnx.iter_graph(model):
-      if isinstance(submodule, (lora_nnx.LoRALinear, lora_nnx.LoRAConv)):
-        nnx_path_str = ".".join(map(str, path))
-
-        matched_key = None
-        if nnx_path_str in lora_params:
-          matched_key = nnx_path_str
-        else:
-          # Fallback: check if any param key matches end of nnx path
-          for k in lora_params:
-            if nnx_path_str.endswith(k) or k.endswith(nnx_path_str):
-              matched_key = k
-              break
-
-        if matched_key and matched_key in lora_params:
-          weights = lora_params[matched_key]
-          if "down" in weights and "up" in weights:
-            if isinstance(submodule, lora_nnx.LoRALinear):
-              submodule.A.value = weights["down"].T
-              submodule.B.value = weights["up"].T
-              assigned_count += 1
-            elif isinstance(submodule, lora_nnx.LoRAConv):
-              submodule.down.kernel.value = weights["down"]
-              submodule.up.kernel.value = weights["up"]
-              assigned_count += 1
-
-            pass
-          else:
-            max_logging.warning(f"LoRA weights for {matched_key} incomplete.")
-    max_logging.log(f"Assigned weights to {assigned_count} LoRA layers in {type(model)}.")
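For context on what the new call sites delegate to: `lora_nnx.merge_lora` itself is not shown in this hunk, so the snippet below is only a minimal sketch, assuming it folds each Diffusers-format `down`/`up` pair directly into the base kernels (`W += scale * up @ down`) rather than keeping separate LoRA layers. The name `merge_lora_sketch`, the restriction to `nnx.Linear`, and the example key format are illustrative assumptions, not the actual implementation.

```python
# Minimal sketch of a merge-style LoRA load (assumption: NOT the actual
# lora_nnx.merge_lora). Folds each down/up pair into the base weights so no
# LoRA layers need to be injected or path-matched afterwards.
import re

import jax.numpy as jnp
from flax import nnx


def merge_lora_sketch(model: nnx.Module, state_dict: dict, scale: float = 1.0):
  # Group flat Diffusers keys such as "blocks.0.attn.to_q.lora.down.weight"
  # into {"blocks.0.attn.to_q": {"down": ..., "up": ...}}.
  pairs = {}
  for key, value in state_dict.items():
    m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", key)
    if m:
      pairs.setdefault(m.group(1), {})[m.group(2)] = jnp.asarray(value)

  for path, submodule in nnx.iter_graph(model):
    if not isinstance(submodule, nnx.Linear):
      continue
    weights = pairs.get(".".join(map(str, path)))
    if not weights or "down" not in weights or "up" not in weights:
      continue
    # Diffusers stores down as (rank, in) and up as (out, rank); an NNX
    # Linear kernel is (in, out), so the delta is (up @ down).T.
    delta = (weights["up"] @ weights["down"]).T
    submodule.kernel.value = submodule.kernel.value + scale * delta
```

Whatever its exact shape, merging at load time removes the need for the injected `LoRALinear`/`LoRAConv` wrappers and the `endswith` path-matching heuristics that the deleted `_assign_weights_to_nnx_model` relied on.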