
Commit 0b40036

Fix
1 parent: d4e22a4

1 file changed: src/maxdiffusion/models/lora_nnx.py
Lines changed: 18 additions & 4 deletions
@@ -253,7 +253,16 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
     w_diff = weights.get("diff", None)
     b_diff = weights.get("diff_b", None)
 
-    if w_diff is not None: w_diff = np.array(w_diff)
+    if w_diff is not None:
+      w_diff = np.array(w_diff)
+      # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed.
+      if isinstance(module, nnx.Conv):
+        if w_diff.ndim == 5:
+          w_diff = w_diff.transpose((2,3,4,1,0))
+        elif w_diff.ndim == 4:
+          w_diff = w_diff.transpose((2,3,1,0))
+      elif isinstance(module, nnx.Linear) and w_diff.ndim == 2:
+        w_diff = w_diff.transpose((1,0))
     if b_diff is not None: b_diff = np.array(b_diff)
 
     # Check for Bias existence
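
For context: PyTorch stores convolution kernels channel-first (OIHW for 2D, OIDHW for 3D), while Flax's nnx.Conv expects channel-last (HWIO / DHWIO), so the axes must be permuted rather than reshaped. Likewise, nnx.Linear keeps its kernel as (in, out) where torch.nn.Linear uses (out, in), hence the plain (1,0) transpose. A minimal numpy sketch of the 2D case, with hypothetical shapes chosen only for illustration:

    import numpy as np

    # Hypothetical PyTorch Conv2d weight in OIHW layout: (out=8, in=4, kh=3, kw=3).
    w_torch = np.arange(8 * 4 * 3 * 3, dtype=np.float32).reshape(8, 4, 3, 3)

    # Permute to the HWIO layout Flax convolutions use: (kh, kw, in, out).
    w_flax = w_torch.transpose((2, 3, 1, 0))
    assert w_flax.shape == (3, 3, 4, 8)

    # Values are untouched; only indices are permuted: flax[h, w, i, o] == torch[o, i, h, w].
    assert w_flax[1, 0, 2, 5] == w_torch[5, 2, 1, 0]

The 5-axis permutation (2,3,4,1,0) in the hunk above is the same idea with an extra depth axis in front of the spatial dims.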
@@ -351,9 +360,14 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
       if stack_w_diff is None:
         stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32)
       wd = np.array(w["diff"])
-      # Reshape if 1x1 conv diff (squeeze spatial dims if needed, or broadcast)
-      if is_conv and wd.ndim != 5: wd = wd.reshape(1, 1, 1, in_feat, out_feat)
-      elif is_linear and wd.ndim != 2: wd = wd.reshape(in_feat, out_feat)
+      # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed.
+      if is_conv:
+        if wd.ndim == 5:
+          wd = wd.transpose((2,3,4,1,0))
+        elif wd.ndim == 4:
+          wd = wd.transpose((2,3,1,0))
+      elif is_linear and wd.ndim == 2:
+        wd = wd.transpose((1,0))
 
       stack_w_diff[i] = wd
       has_diff = True
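
A note on why this hunk replaces reshape with transpose: reshape only relabels the flat memory order, so when used to change layout it silently scrambles which input channel pairs with which output channel, whereas transpose moves the axes and preserves every (in, out) pairing. A quick numpy check, again with hypothetical shapes:

    import numpy as np

    # Hypothetical 1x1-conv LoRA diff in PyTorch OIHW layout: (out=2, in=3, 1, 1).
    wd = np.arange(6, dtype=np.float32).reshape(2, 3, 1, 1)

    transposed = wd.transpose((2, 3, 1, 0))  # HWIO (1, 1, 3, 2): correct mapping
    reshaped = wd.reshape(1, 1, 3, 2)        # same shape, but a flat-order relabeling

    print(transposed[0, 0])  # [[0. 3.] [1. 4.] [2. 5.]] -- (in, out) pairs preserved
    print(reshaped[0, 0])    # [[0. 1.] [2. 3.] [4. 5.]] -- channels scrambled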
