@@ -278,61 +278,6 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
   return params, rank, network_alphas
 
 
-def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
-  pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
-  transformer_params = flatten_dict(unfreeze(params["transformer"]))
-  network_alphas = {}
-  rank = None
-  for pt_key, tensor in pt_state_dict.items():
-    renamed_pt_key = rename_key(pt_key)
-    renamed_pt_key = renamed_pt_key.replace("lora_unet_", "")
-    renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down")
-    renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up")
-
-    if "double_blocks" in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.down", f"attn.i_proj.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.up", f"attn.i_proj.lora-{adapter_name}.up")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.down", f"attn.e_proj.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.up", f"attn.e_proj.lora-{adapter_name}.up")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.down", f"attn.i_qkv.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down")
-      renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up")
-
-      renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
-      renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
-      renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")
-      renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2")
-      renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin")
-      renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj")
-      renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv")
-      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0")
-      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2")
-      renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin")
-    elif "single_blocks" in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1")
-      renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2")
-      renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin")
-
-    renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
-
-    pt_tuple_key = tuple(renamed_pt_key.split("."))
-    if "alpha" in pt_tuple_key:
-      pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel")
-      network_alphas[tuple([*pt_tuple_key])] = tensor.item()  # noqa: C409
-      pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel")
-      network_alphas[tuple([*pt_tuple_key])] = tensor.item()  # noqa: C409
-    else:
-      if pt_tuple_key[-2] == "up":
-        rank = tensor.shape[1]
-      transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)  # noqa: C409
-
-  params["transformer"] = unflatten_dict(transformer_params)
-
-  return params, rank, network_alphas
-
-
 def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
   # Step 1: Convert pytorch tensor to numpy
   # sometimes we load weights in bf16 and numpy doesn't support it
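
For reference, the removed helper followed the usual PyTorch-to-Flax LoRA conversion pattern: flatten the Flax param tree into tuple keys, rewrite each PyTorch dotted key into the matching Flax path (with "weight" becoming "kernel" and the matrix transposed), then unflatten. Below is a minimal, self-contained sketch of that pattern. It deliberately omits this repo's rename_key and config objects, and the helper name and key scheme are illustrative assumptions, not the project's API.

import jax.numpy as jnp
from flax.core import unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

def merge_lora_weights(params, pt_state_dict, adapter_name, dtype=jnp.float32):
  """Sketch only: copy PyTorch LoRA tensors into a Flax param tree.

  Assumes pt_state_dict values are already numpy arrays and that a PyTorch key
  "<module>.lora_down.weight" should land at the Flax path
  ("<module>", "lora-<adapter>", "down", "kernel"). A real converter (like the
  function removed above) needs model-specific renaming rules on top of this.
  """
  flat = flatten_dict(unfreeze(params))
  rank = None
  for pt_key, tensor in pt_state_dict.items():
    key = pt_key.replace("lora_down", f"lora-{adapter_name}.down")
    key = key.replace("lora_up", f"lora-{adapter_name}.up")
    key = key.replace("weight", "kernel")
    flax_key = tuple(key.split("."))
    if flax_key[-2] == "up":
      rank = tensor.shape[1]  # PyTorch up-projection weight has shape (out_features, rank)
    # PyTorch Linear weights are (out, in); Flax Dense kernels are (in, out), hence the transpose.
    flat[flax_key] = jnp.asarray(tensor.T, dtype=dtype)
  return unflatten_dict(flat), rank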