@@ -88,7 +88,7 @@ def load_lora_weights(
         return params, rank, network_alphas

     @classmethod
-    def _get_lora_layer(cls, module_path, module, rank, network_alphas):
+    def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
         is_conv = any("conv" in str_ for str_ in module_path)
         network_alpha = network_alphas.get(module_path, None)
         if is_conv:
@@ -105,7 +105,7 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
                 dtype=module.dtype,
                 weights_dtype=module.param_dtype,
                 precision=module.precision,
-                name="lora",
+                name=f"lora-{adapter_name}",
             )
         else:
             lora_module = LoRALinearLayer(
@@ -115,39 +115,41 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
                 dtype=module.dtype,
                 weights_dtype=module.param_dtype,
                 precision=module.precision,
-                name="lora",
+                name=f"lora-{adapter_name}",
             )
         return lora_module

-    def rename_for_interceptor(params_keys, network_alphas):
+    def rename_for_interceptor(params_keys, network_alphas, adapter_name):
         new_params_keys = []
         new_network_alphas = {}
+        lora_name = f"lora-{adapter_name}"
         for layer_lora in params_keys:
-            if "lora" in layer_lora:
-                new_layer_lora = layer_lora[: layer_lora.index("lora")]
+            if lora_name in layer_lora:
+                new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
                 if new_layer_lora not in new_params_keys:
                     new_params_keys.append(new_layer_lora)
                     network_alpha = network_alphas[layer_lora]
                     new_network_alphas[new_layer_lora] = network_alpha
         return new_params_keys, new_network_alphas

     @classmethod
-    def make_lora_interceptor(cls, params, rank, network_alphas):
+    def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
         # Only unet interceptor supported for now.
         network_alphas_for_interceptor = {}

         unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
-        lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
+        lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas, adapter_name)
         network_alphas_for_interceptor.update(unet_alphas)

         text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
-        text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
+        text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas, adapter_name)
         lora_keys.extend(text_encoder_keys)
         network_alphas_for_interceptor.update(text_encoder_alphas)
-
         if "text_encoder_2" in params.keys():
             text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
-            text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
+            text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(
+                text_encoder_2_keys, network_alphas, adapter_name
+            )
             lora_keys.extend(text_encoder_2_keys)
             network_alphas_for_interceptor.update(text_encoder_2_alphas)

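For reference, the key truncation that `rename_for_interceptor` now performs per adapter can be reproduced standalone; a minimal sketch, assuming a made-up param tree (real keys are much deeper and come from the merged LoRA weights):

```python
import flax

adapter_name = "my_adapter"
lora_name = f"lora-{adapter_name}"

# Toy stand-in for params["unet"] after the LoRA weights have been merged in.
params_unet = {
    "down_blocks_0": {
        "proj": {
            lora_name: {
                "up": {"kernel": 0.0},
                "down": {"kernel": 0.0},
            }
        }
    }
}

# flatten_dict yields tuple keys such as
# ("down_blocks_0", "proj", "lora-my_adapter", "up", "kernel").
flat_keys = flax.traverse_util.flatten_dict(params_unet).keys()

# Truncate each key at the adapter scope to recover the wrapped module's path.
module_paths = []
for key in flat_keys:
    if lora_name in key:
        path = key[: key.index(lora_name)]
        if path not in module_paths:  # "up" and "down" collapse to one path
            module_paths.append(path)

print(module_paths)  # [('down_blocks_0', 'proj')]
```

Matching on `f"lora-{adapter_name}"` instead of the bare string `"lora"` is what keeps several adapters from colliding in the same param tree.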
@@ -161,7 +163,7 @@ def _intercept(next_fn, args, kwargs, context):
             if context.method_name == "__call__":
                 module_path = context.module.path
                 if module_path in lora_keys:
-                    lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
+                    lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
                     return lora_layer(h, *args, **kwargs)
             return h

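To see why the module paths matter, here is a self-contained toy of the same `flax.linen.intercept_methods` pattern; it skips `LoRALinearLayer` and instead closes over precomputed low-rank factors, and every name and shape in it is invented:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyUNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8, name="proj_in")(x)
        return nn.Dense(4, name="proj_out")(x)

rank = 2
# Pretend these factors came from a converted "lora-sketch" checkpoint,
# keyed by module path exactly like lora_keys in the diff above.
lora_factors = {
    ("proj_in",): (jnp.full((4, rank), 0.01), jnp.full((rank, 8), 0.01)),
}

def _intercept(next_fn, args, kwargs, context):
    h = next_fn(*args, **kwargs)
    if context.method_name == "__call__" and context.module.path in lora_factors:
        down, up = lora_factors[context.module.path]
        h = h + args[0] @ down @ up  # h + x @ A @ B, the LoRA residual
    return h

x = jnp.ones((1, 4))
model = TinyUNet()
variables = model.init(jax.random.PRNGKey(0), x)
with nn.intercept_methods(_intercept):
    y = model.apply(variables, x)  # proj_in's output now carries the residual
```

Because the interceptor resolves layers by `context.module.path` at call time, switching adapters only means building a different interceptor; the wrapped model itself stays untouched.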
@@ -290,5 +292,5 @@ def load_lora(cls, state_dict, network_alphas, params, adapter_name=None, _pipel
             `default_{i}` where i is the total number of adapters being loaded.
         """
         # Load the layers corresponding to Unet.
-        params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas)
+        params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas, adapter_name)
         return params, rank, network_alphas
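`convert_lora_pytorch_state_dict_to_flax` itself is not shown in this diff, but the effect of threading `adapter_name` through it can be sketched; a rough, simplified illustration with invented key patterns and shapes (real checkpoints use much longer diffusers-style names and also carry alpha entries):

```python
import flax
import numpy as np

adapter_name = "style_a"
# Toy PyTorch-style LoRA state dict: Linear weights are (out, in).
state_dict = {
    "unet.proj.lora_down.weight": np.zeros((2, 4)),  # (rank, in_features)
    "unet.proj.lora_up.weight": np.zeros((8, 2)),    # (out_features, rank)
}
rank = state_dict["unet.proj.lora_down.weight"].shape[0]

flat = {}
for torch_key, weight in state_dict.items():
    *path, direction, _ = torch_key.split(".")  # ["unet", "proj"], "lora_down", "weight"
    direction = {"lora_down": "down", "lora_up": "up"}[direction]
    # Flax Dense kernels are (in, out), so transpose, and scope the entry
    # under the per-adapter name that _get_lora_layer will look for.
    flat[(*path[1:], f"lora-{adapter_name}", direction, "kernel")] = weight.T

params_unet = flax.traverse_util.unflatten_dict(flat)
# {"proj": {"lora-style_a": {"down": {"kernel": ...}, "up": {"kernel": ...}}}}
```

With the scope named per adapter at conversion time, `rename_for_interceptor` and the interceptor can each address one adapter's weights without clobbering another's.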