Skip to content

Commit 5dcca11

Browse files
committed
Fix
1 parent 1c01078 commit 5dcca11

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def __init__(
5050

5151
# Infer dimensions from the base layer
5252
# nnx.Linear stores weights in 'kernel' usually shaped (in_features, out_features)
53-
in_features, out_features = base_layer.kernel.shape
53+
k_shape = base_layer.kernel.shape
54+
in_features = k_shape[0]
55+
out_features = int(jnp.prod(jnp.array(k_shape[1:])))
5456

5557
# 1. Down Projection (A): Random Initialization
5658
# Projects inputs down to rank 'r'
@@ -73,6 +75,11 @@ def __call__(self, x):
7375
# Equation: (x @ A @ B) * scaling
7476
lora_out = (x @ self.A.value @ self.B.value) * self.scaling()
7577

78+
# If base layer kernel was >2D, its output might be >2D on feature axis.
79+
# We need to reshape lora_out to match base_out's trailing dimensions if they don't match.
80+
if len(self.base_layer.kernel.shape) > 2:
81+
lora_out = lora_out.reshape(x.shape[:-1] + self.base_layer.kernel.shape[1:])
82+
7683
# 3. Sum
7784
return base_out + lora_out
7885

0 commit comments

Comments
 (0)