@@ -121,21 +121,35 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
121121
def rename_for_interceptor(params_keys, network_alphas):
    """Strip the trailing ``lora...`` suffix from each LoRA parameter key.

    Returns the deduplicated list of truncated key prefixes (insertion
    order preserved) together with a new dict mapping each prefix to the
    network alpha of the first full key that produced it.  The input
    ``network_alphas`` mapping is never mutated.
    """
    renamed_keys = []
    renamed_alphas = {}
    for full_key in params_keys:
        # Keys without a "lora" segment are not LoRA weights; skip them.
        if "lora" not in full_key:
            continue
        prefix = full_key[: full_key.index("lora")]
        # Several lora_* tensors share one prefix; record it only once.
        if prefix in renamed_keys:
            continue
        renamed_keys.append(prefix)
        renamed_alphas[prefix] = network_alphas[full_key]
    return renamed_keys, renamed_alphas
133133
134134 @classmethod
135135 def make_lora_interceptor (cls , params , rank , network_alphas ):
136136 # Only unet interceptor supported for now.
137+ network_alphas_for_interceptor = {}
138+
137139 unet_lora_keys = flax .traverse_util .flatten_dict (params ["unet" ]).keys ()
138- unet_lora_keys , network_alphas = cls .rename_for_interceptor (unet_lora_keys , network_alphas )
140+ lora_keys , unet_alphas = cls .rename_for_interceptor (unet_lora_keys , network_alphas )
141+ network_alphas_for_interceptor .update (unet_alphas )
142+
143+ text_encoder_keys = flax .traverse_util .flatten_dict (params ["text_encoder" ]).keys ()
144+ text_encoder_keys , text_encoder_alphas = cls .rename_for_interceptor (text_encoder_keys , network_alphas )
145+ lora_keys .extend (text_encoder_keys )
146+ network_alphas_for_interceptor .update (text_encoder_alphas )
147+
148+ if "text_encoder_2" in params .keys ():
149+ text_encoder_2_keys = flax .traverse_util .flatten_dict (params ["text_encoder_2" ]).keys ()
150+ text_encoder_2_keys , text_encoder_2_alphas = cls .rename_for_interceptor (text_encoder_2_keys , network_alphas )
151+ lora_keys .extend (text_encoder_2_keys )
152+ network_alphas_for_interceptor .update (text_encoder_2_alphas )
139153
140154 def _intercept (next_fn , args , kwargs , context ):
141155 mod = context .module
@@ -146,8 +160,8 @@ def _intercept(next_fn, args, kwargs, context):
146160 h = next_fn (* args , ** kwargs )
147161 if context .method_name == "__call__" :
148162 module_path = context .module .path
149- if module_path in unet_lora_keys :
150- lora_layer = cls ._get_lora_layer (module_path , context .module , rank , network_alphas )
163+ if module_path in lora_keys :
164+ lora_layer = cls ._get_lora_layer (module_path , context .module , rank , network_alphas_for_interceptor )
151165 return lora_layer (h , * args , ** kwargs )
152166 return h
153167
0 commit comments