-
Notifications
You must be signed in to change notification settings - Fork 69
Flux lora #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Flux lora #148
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
d5ac715
add support for flux vae. ~ wip
jfacevedo-google 394ebd1
test for flux vae both encoding and decoding.
jfacevedo-google 025642b
add clip text encoder test
jfacevedo-google a2b7f82
remove transformers inside maxdiffusion, add transformers dependency.…
jfacevedo-google 2b83d5c
add double block to flux
jfacevedo-google 37d9f00
forward pass for single double block.
jfacevedo-google 8785d00
trying to use scan.
jfacevedo-google cb91d5e
add single stream block
jfacevedo-google bb71982
finish transformer
jfacevedo-google 3eb5729
convert pt weights to flax and load transformer state.
jfacevedo-google 956341e
apply fsdp sharding, do one forward pass in the transformer.
jfacevedo-google 4b64f5d
wip - generate fn
jfacevedo-google 860e76e
working loop, bad generation
jfacevedo-google 93a3bb6
e2e, encoder offloading.
jfacevedo-google 601f40c
add missing conversions of pt to jax weights.
jfacevedo-google d16c020
support both dev and schnell loading. Images still incorrect.
jfacevedo-google 4a12b39
flux schnell working
jfacevedo-google 9871c7d
removed unused code.
jfacevedo-google a75a125
support dev
jfacevedo-google 05b6fc8
add sentencepiece requirement
jfacevedo-google df25e47
fix repeated double and single blocks.
jfacevedo-google 587bc6a
optimized flash block sizes for trillium.
jfacevedo-google 8905362
Merge branch 'main' into flux_impl
jfacevedo-google b87443f
clean up code and lint
jfacevedo-google 37df8b9
fix sdxl generate smoke tests.
jfacevedo-google e56825f
fix rest of unit tests.
jfacevedo-google 064a3a7
update readme and some dependencies.
entrpn fa1c23b
remove unused dependencies.
entrpn b4d0502
initial lora implementation for flux
jfacevedo-google 9e07358
adding another format lora support.
jfacevedo-google 4c68d53
Merge branch 'main' into flux_lora
jfacevedo-google 1f2e65c
Support other format loras. update readme. Run code_style.
jfacevedo-google 24ee4cc
ruff
jfacevedo-google File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Union, Dict | ||
| from .lora_base import LoRABaseMixin | ||
| from ..models.lora import LoRALinearLayer, BaseLoRALayer | ||
| import jax.numpy as jnp | ||
| from flax.traverse_util import flatten_dict | ||
| from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax | ||
| from huggingface_hub.utils import validate_hf_hub_args | ||
|
|
||
|
|
||
| class FluxLoraLoaderMixin(LoRABaseMixin): | ||
|
|
||
| _lora_lodable_modules = ["transformer", "text_encoder"] | ||
|
|
||
| def load_lora_weights( | ||
| self, | ||
| config, | ||
| pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], | ||
| params, | ||
| adapter_name=None, | ||
| **kwargs, | ||
| ): | ||
| state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) | ||
|
|
||
| params, rank, network_alphas = self.load_lora( | ||
| config, | ||
| state_dict, | ||
| params=params, | ||
| adapter_name=adapter_name, | ||
| ) | ||
|
|
||
| return params, rank, network_alphas | ||
|
|
||
| def rename_for_interceptor(params_keys, network_alphas, adapter_name): | ||
| new_params_keys = [] | ||
| new_network_alphas = {} | ||
| lora_name = f"lora-{adapter_name}" | ||
| for layer_lora in params_keys: | ||
| if lora_name in layer_lora: | ||
| new_layer_lora = layer_lora[: layer_lora.index(lora_name)] | ||
| if new_layer_lora not in new_params_keys: | ||
| new_params_keys.append(new_layer_lora) | ||
| network_alpha = network_alphas.get(layer_lora, None) | ||
| new_network_alphas[new_layer_lora] = network_alpha | ||
| return new_params_keys, new_network_alphas | ||
|
|
||
| @classmethod | ||
| def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): | ||
| network_alphas_for_interceptor = {} | ||
|
|
||
| transformer_keys = flatten_dict(params["transformer"]).keys() | ||
| lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name) | ||
| network_alphas_for_interceptor.update(transformer_alphas) | ||
|
|
||
| def _intercept(next_fn, args, kwargs, context): | ||
| mod = context.module | ||
| while mod is not None: | ||
| if isinstance(mod, BaseLoRALayer): | ||
| return next_fn(*args, **kwargs) | ||
| mod = mod.parent | ||
| h = next_fn(*args, **kwargs) | ||
| if context.method_name == "__call__": | ||
| module_path = context.module.path | ||
| if module_path in lora_keys: | ||
| lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name) | ||
| return lora_layer(h, *args, **kwargs) | ||
| return h | ||
|
|
||
| return _intercept | ||
|
|
||
| @classmethod | ||
| def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name): | ||
| network_alpha = network_alphas.get(module_path, None) | ||
| lora_module = LoRALinearLayer( | ||
| out_features=module.features, | ||
| rank=rank, | ||
| network_alpha=network_alpha, | ||
| dtype=module.dtype, | ||
| weights_dtype=module.param_dtype, | ||
| precision=module.precision, | ||
| name=f"lora-{adapter_name}", | ||
| ) | ||
| return lora_module | ||
|
|
||
| @classmethod | ||
| @validate_hf_hub_args | ||
| def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): | ||
|
|
||
| cache_dir = kwargs.pop("cache_dir", None) | ||
| force_download = kwargs.pop("force_download", False) | ||
| proxies = kwargs.pop("proxies", None) | ||
| local_files_only = kwargs.pop("local_files_only", None) | ||
| use_auth_token = kwargs.pop("use_auth_token", None) | ||
| revision = kwargs.pop("revision", None) | ||
| subfolder = kwargs.pop("subfolder", None) | ||
| weight_name = kwargs.pop("weight_name", None) | ||
| use_safetensors = kwargs.pop("use_safetensors", None) | ||
| resume_download = kwargs.pop("resume_download", False) | ||
|
|
||
| allow_pickle = False | ||
| if use_safetensors is None: | ||
| use_safetensors = True | ||
| allow_pickle = True | ||
|
|
||
| user_agent = { | ||
| "file_type": "attn_procs_weights", | ||
| "framework": "pytorch", | ||
| } | ||
|
|
||
| state_dict = cls._fetch_state_dict( | ||
| pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path, | ||
| weight_name=weight_name, | ||
| use_safetensors=use_safetensors, | ||
| local_files_only=local_files_only, | ||
| cache_dir=cache_dir, | ||
| force_download=force_download, | ||
| resume_download=resume_download, | ||
| proxies=proxies, | ||
| use_auth_token=use_auth_token, | ||
| revision=revision, | ||
| subfolder=subfolder, | ||
| user_agent=user_agent, | ||
| allow_pickle=allow_pickle, | ||
| ) | ||
|
|
||
| return state_dict | ||
|
|
||
| @classmethod | ||
| def load_lora(cls, config, state_dict, params, adapter_name=None): | ||
| params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name) | ||
| return params, rank, network_alphas |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Why not use maybe_load_lora() from maxdiffusion_utils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe_use_lora() in maxdiffusion_utils is specific to sdxl and won't work with flux. I should rename that method to maybe_load_sdxl_lora. I will create an issue to track this and add it on a different commit. Thanks for the review.