File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -50,7 +50,9 @@ def __init__(
5050
5151 # Infer dimensions from the base layer
5252 # nnx.Linear stores weights in 'kernel' usually shaped (in_features, out_features)
53- in_features , out_features = base_layer .kernel .shape
53+ k_shape = base_layer .kernel .shape
54+ in_features = k_shape [0 ]
55+ out_features = int (jnp .prod (jnp .array (k_shape [1 :])))
5456
5557 # 1. Down Projection (A): Random Initialization
5658 # Projects inputs down to rank 'r'
@@ -73,6 +75,11 @@ def __call__(self, x):
7375 # Equation: (x @ A @ B) * scaling
7476 lora_out = (x @ self .A .value @ self .B .value ) * self .scaling ()
7577
78+ # If the base layer's kernel has more than two dimensions, base_out may have several trailing feature axes.
79+ # In that case, reshape lora_out from a flat feature axis to the kernel's trailing dimensions so it can be added to base_out.
80+ if len (self .base_layer .kernel .shape ) > 2 :
81+ lora_out = lora_out .reshape (x .shape [:- 1 ] + self .base_layer .kernel .shape [1 :])
82+
7683 # 3. Sum
7784 return base_out + lora_out
7885
You can’t perform that action at this time.
0 commit comments