Skip to content

Commit 24ee4cc

Browse files
ruff
1 parent 1f2e65c commit 24ee4cc

3 files changed

Lines changed: 4 additions & 8 deletions

File tree

src/maxdiffusion/loaders/flux_lora_pipeline.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
from .lora_base import LoRABaseMixin
1717
from ..models.lora import LoRALinearLayer, BaseLoRALayer
1818
import jax.numpy as jnp
19-
from flax.traverse_util import flatten_dict, unflatten_dict
20-
from flax.core.frozen_dict import unfreeze
19+
from flax.traverse_util import flatten_dict
2120
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
2221
from huggingface_hub.utils import validate_hf_hub_args
23-
from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)
2422

2523

2624
class FluxLoraLoaderMixin(LoRABaseMixin):
@@ -109,7 +107,6 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
109107
revision = kwargs.pop("revision", None)
110108
subfolder = kwargs.pop("subfolder", None)
111109
weight_name = kwargs.pop("weight_name", None)
112-
unet_config = kwargs.pop("unet_config", None)
113110
use_safetensors = kwargs.pop("use_safetensors", None)
114111
resume_download = kwargs.pop("resume_download", False)
115112

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,13 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
265265
pt_tuple_key = tuple(renamed_pt_key.split("."))
266266
if "alpha" in pt_tuple_key:
267267
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel")
268-
network_alphas[tuple([*pt_tuple_key])] = tensor.item()
268+
network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409
269269
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel")
270-
network_alphas[tuple([*pt_tuple_key])] = tensor.item()
270+
network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409
271271
else:
272272
if pt_tuple_key[-2] == "up":
273273
rank = tensor.shape[1]
274-
transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)
274+
transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype) # noqa: C409
275275

276276
params["transformer"] = unflatten_dict(transformer_params)
277277

src/maxdiffusion/tests/generate_flux_smoke_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def download_blob(gcs_file, local_file):
2222
gcs_dir_arr = gcs_file.replace("gs://", "").split("/")
2323
storage_client = storage.Client()
2424
bucket = storage_client.get_bucket(gcs_dir_arr[0])
25-
blob_loc = "/".join(gcs_dir_arr[1:])
2625
blob = bucket.blob("/".join(gcs_dir_arr[1:]))
2726
blob.download_to_filename(local_file)
2827

0 commit comments

Comments
 (0)