Commit ac542bc

Fix
1 parent 5dcca11 commit ac542bc

2 files changed: 69 additions & 64 deletions

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 6 additions & 63 deletions
```diff
@@ -52,79 +52,22 @@ def load_lora_weights(
 
     # Handle high noise model
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
-      max_logging.log(f"Injecting LoRA into high_noise_transformer with rank={rank}")
-      lora_nnx.inject_lora(
-          pipeline.high_noise_transformer, rank=rank, scale=scale, rngs=nnx.Rngs(rng), target_linear=True, target_conv=True
-      )
-      h_state_dict, h_alphas = lora_loader.lora_state_dict(
+      max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
+      h_state_dict, _ = lora_loader.lora_state_dict(
           lora_model_path, weight_name=high_noise_weight_name, **kwargs
       )
-      self._assign_weights_to_nnx_model(pipeline.high_noise_transformer, h_state_dict, h_alphas if h_alphas else {})
+      lora_nnx.merge_lora(pipeline.high_noise_transformer, h_state_dict, scale)
     else:
       max_logging.warning("high_noise_transformer not found or no weight name provided for LoRA.")
 
     # Handle low noise model
     if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
-      max_logging.log(f"Injecting LoRA into low_noise_transformer with rank={rank}")
-      lora_nnx.inject_lora(
-          pipeline.low_noise_transformer, rank=rank, scale=scale, rngs=nnx.Rngs(rng), target_linear=True, target_conv=True
-      )
-      l_state_dict, l_alphas = lora_loader.lora_state_dict(
+      max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
+      l_state_dict, _ = lora_loader.lora_state_dict(
           lora_model_path, weight_name=low_noise_weight_name, **kwargs
       )
-      self._assign_weights_to_nnx_model(pipeline.low_noise_transformer, l_state_dict, l_alphas if l_alphas else {})
+      lora_nnx.merge_lora(pipeline.low_noise_transformer, l_state_dict, scale)
     else:
       max_logging.warning("low_noise_transformer not found or no weight name provided for LoRA.")
 
     return pipeline
-
-  def _assign_weights_to_nnx_model(self, model: nnx.Module, state_dict: dict, network_alphas: dict):
-    """
-    Assigns weights from a Diffusers-formatted state dict to
-    injected LoRALinear/LoRAConv layers in an NNX model.
-    """
-    lora_params = {}
-    for k, v in state_dict.items():
-      m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
-      if not m:
-        m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
-
-      if m:
-        module_path_str, weight_type = m.group(1), m.group(2)
-        if module_path_str not in lora_params:
-          lora_params[module_path_str] = {}
-        lora_params[module_path_str][weight_type] = jnp.array(v)
-      else:
-        max_logging.warning(f"Could not parse LoRA key: {k}")
-
-    assigned_count = 0
-    for path, submodule in nnx.iter_graph(model):
-      if isinstance(submodule, (lora_nnx.LoRALinear, lora_nnx.LoRAConv)):
-        nnx_path_str = ".".join(map(str, path))
-
-        matched_key = None
-        if nnx_path_str in lora_params:
-          matched_key = nnx_path_str
-        else:
-          # Fallback: check if any param key matches the end of the nnx path
-          for k in lora_params:
-            if nnx_path_str.endswith(k) or k.endswith(nnx_path_str):
-              matched_key = k
-              break
-
-        if matched_key and matched_key in lora_params:
-          weights = lora_params[matched_key]
-          if "down" in weights and "up" in weights:
-            if isinstance(submodule, lora_nnx.LoRALinear):
-              submodule.A.value = weights["down"].T
-              submodule.B.value = weights["up"].T
-              assigned_count += 1
-            elif isinstance(submodule, lora_nnx.LoRAConv):
-              submodule.down.kernel.value = weights["down"]
-              submodule.up.kernel.value = weights["up"]
-              assigned_count += 1
-          else:
-            max_logging.warning(f"LoRA weights for {matched_key} incomplete.")
-    max_logging.log(f"Assigned weights to {assigned_count} LoRA layers in {type(model)}.")
```

src/maxdiffusion/models/lora_nnx.py

Lines changed: 63 additions & 1 deletion
```diff
@@ -15,8 +15,10 @@
 """
 
 from typing import Union, Tuple, Optional
+import re
 import jax.numpy as jnp
 from flax import nnx
+from .. import max_logging
 
 class BaseLoRALayer(nnx.Module):
   """
@@ -196,4 +198,64 @@ def inject_lora(
         wrapper = LoRAConv(base_layer=module, rank=rank, scale=scale, rngs=rngs)
         setattr(parent, attr_name, wrapper)
 
-  return model
+  return model
+
+
+def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
+  """
+  Merges weights from a Diffusers-formatted state dict directly
+  into the kernels of nnx.Linear and nnx.Conv layers.
+  """
+  lora_params = {}
+  for k, v in state_dict.items():
+    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+    if not m:
+      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+
+    if m:
+      module_path_str, weight_type = m.group(1), m.group(2)
+      if module_path_str not in lora_params:
+        lora_params[module_path_str] = {}
+      lora_params[module_path_str][weight_type] = jnp.array(v)
+    else:
+      max_logging.warning(f"Could not parse LoRA key: {k}")
+
+  assigned_count = 0
+  for path, module in nnx.iter_graph(model):
+    nnx_path_str = ".".join(map(str, path))
+
+    matched_key = None
+    if nnx_path_str in lora_params:
+      matched_key = nnx_path_str
+    else:
+      # Fallback: check if any param key matches the end of the nnx path
+      for k in lora_params:
+        if nnx_path_str.endswith(k):
+          matched_key = k
+          break
+
+    if matched_key and matched_key in lora_params:
+      weights = lora_params[matched_key]
+      if "down" in weights and "up" in weights:
+        if isinstance(module, nnx.Linear):
+          down_w = weights["down"]  # (rank, in_features)
+          up_w = weights["up"]  # (out_features_flat, rank)
+          # delta = A @ B = down.T @ up.T
+          delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
+          module.kernel.value += delta * scale
+          assigned_count += 1
+        elif isinstance(module, nnx.Conv):
+          if module.kernel_size == (1, 1):
+            down_w = weights["down"]  # (1, 1, in_c, rank)
+            up_w = weights["up"]  # (1, 1, rank, out_c)
+            # delta = down @ up over the channel dimensions
+            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
+            module.kernel.value += delta * scale
+            assigned_count += 1
+          else:
+            raise NotImplementedError(
+                f"Merging LoRA weights for Conv layer {matched_key} "
+                f"with kernel_size {module.kernel_size} > 1 is not supported."
+            )
+      else:
+        max_logging.warning(f"LoRA weights for {matched_key} incomplete.")
+  max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
```
