# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Union

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import unfreeze
from huggingface_hub.utils import validate_hf_hub_args

from .lora_base import LoRABaseMixin
from ..models.lora import LoRALinearLayer, BaseLoRALayer
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)


class FluxLoraLoaderMixin(LoRABaseMixin):
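  """Mixin for loading Flux LoRA weights into Flax pipeline params.

  LoRA weights are converted from their PyTorch layout and merged into the
  Flax `params` tree; at call time, a Flax method interceptor routes matching
  module outputs through `LoRALinearLayer` instances.
  """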

  _lora_loadable_modules = ["transformer", "text_encoder"]

  def load_lora_weights(
      self,
      config,
      pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
      params,
      adapter_name=None,
      **kwargs,
  ):
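    """Download (or read) a LoRA state dict and load it into `params`.

    Returns the updated `params` along with the LoRA `rank` and
    `network_alphas`. A minimal usage sketch (the repo id is hypothetical):

        params, rank, network_alphas = pipeline.load_lora_weights(
            config,
            "some-org/some-flux-lora",  # hypothetical Hub repo id
            params=params,
            adapter_name="my_adapter",
        )
    """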
    state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    params, rank, network_alphas = self.load_lora(
        config,
        state_dict,
        params=params,
        adapter_name=adapter_name,
    )

    return params, rank, network_alphas

  @classmethod
  def rename_for_interceptor(cls, params_keys, network_alphas, adapter_name):
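    """Map flattened LoRA param keys back to their base module paths.

    Keys are tuples produced by `flax.traverse_util.flatten_dict`; the
    `lora-{adapter_name}` segment and everything after it are stripped so the
    remaining prefix matches `context.module.path` inside the interceptor.
    Returns the deduplicated paths and their network alphas keyed by path.
    """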
    new_params_keys = []
    new_network_alphas = {}
    lora_name = f"lora-{adapter_name}"
    for layer_lora in params_keys:
      if lora_name in layer_lora:
        # Keep only the module path prefix that precedes the LoRA sublayer.
        new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
        if new_layer_lora not in new_params_keys:
          new_params_keys.append(new_layer_lora)
          network_alpha = network_alphas[layer_lora]
          new_network_alphas[new_layer_lora] = network_alpha
    return new_params_keys, new_network_alphas

  @classmethod
  def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
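    """Create a Flax method interceptor that applies LoRA to matching modules.

    The returned function is intended for `flax.linen.intercept_methods`; a
    usage sketch follows this method. Note that only `params["transformer"]`
    is scanned for LoRA keys here.
    """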
    network_alphas_for_interceptor = {}

    transformer_keys = flatten_dict(params["transformer"]).keys()
    lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
    network_alphas_for_interceptor.update(transformer_alphas)

    def _intercept(next_fn, args, kwargs, context):
      # Never intercept calls made from inside a LoRA layer itself,
      # otherwise the interceptor would wrap its own layers and recurse.
      mod = context.module
      while mod is not None:
        if isinstance(mod, BaseLoRALayer):
          return next_fn(*args, **kwargs)
        mod = mod.parent
      h = next_fn(*args, **kwargs)
      if context.method_name == "__call__":
        module_path = context.module.path
        if module_path in lora_keys:
          # Route the base layer's output through the matching LoRA layer.
          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
          return lora_layer(h, *args, **kwargs)
      return h

    return _intercept

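  # A minimal usage sketch for the interceptor, assuming a Flux `transformer`
  # module and its inputs exist in the caller's scope (names illustrative):
  #
  #   import flax.linen as nn
  #
  #   interceptor = FluxLoraLoaderMixin.make_lora_interceptor(
  #       params, rank, network_alphas, adapter_name="my_adapter"
  #   )
  #   with nn.intercept_methods(interceptor):
  #     out = transformer.apply({"params": params["transformer"]}, *model_inputs)
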
  @classmethod
  def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
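    """Build a `LoRALinearLayer` that mirrors the intercepted module's config."""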
    network_alpha = network_alphas.get(module_path, None)
    lora_module = LoRALinearLayer(
        out_features=module.features,
        rank=rank,
        network_alpha=network_alpha,
        dtype=module.dtype,
        weights_dtype=module.param_dtype,
        precision=module.precision,
        name=f"lora-{adapter_name}",
    )
    return lora_module

  @classmethod
  @validate_hf_hub_args
  def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
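    """Fetch a LoRA state dict from a local path, dict, or the Hugging Face Hub.

    Download-related kwargs (`cache_dir`, `revision`, `weight_name`,
    `use_safetensors`, ...) are popped here and forwarded to
    `cls._fetch_state_dict`.
    """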
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    unet_config = kwargs.pop("unet_config", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    resume_download = kwargs.pop("resume_download", False)

    allow_pickle = False
    if use_safetensors is None:
      # Default to safetensors but fall back to pickle checkpoints when the
      # caller did not explicitly request safetensors.
      use_safetensors = True
      allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )

    return state_dict

  @classmethod
  def load_lora(cls, config, state_dict, params, adapter_name=None):
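    """Convert a PyTorch Flux LoRA state dict to Flax and merge it into `params`."""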
    params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
    return params, rank, network_alphas