@@ -253,7 +253,16 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
     w_diff = weights.get("diff", None)
     b_diff = weights.get("diff_b", None)
 
-    if w_diff is not None: w_diff = np.array(w_diff)
+    if w_diff is not None:
+        w_diff = np.array(w_diff)
+        # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed.
+        if isinstance(module, nnx.Conv):
+            if w_diff.ndim == 5:
+                w_diff = w_diff.transpose((2, 3, 4, 1, 0))
+            elif w_diff.ndim == 4:
+                w_diff = w_diff.transpose((2, 3, 1, 0))
+        elif isinstance(module, nnx.Linear) and w_diff.ndim == 2:
+            w_diff = w_diff.transpose((1, 0))
     if b_diff is not None: b_diff = np.array(b_diff)
 
     # Check for Bias existence
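The permutations are easy to sanity-check on dummy arrays. A minimal sketch (not part of the patch; shapes are made up for illustration) showing that `(2, 3, 1, 0)` maps PyTorch's OIHW layout to Flax's HWIO, and `(2, 3, 4, 1, 0)` maps OIDHW to DHWIO:

```python
import numpy as np

# PyTorch Conv2d kernels are OIHW = (out_ch, in_ch, kH, kW);
# Flax nnx.Conv expects HWIO = (kH, kW, in_ch, out_ch).
w_2d = np.zeros((64, 32, 3, 3))                                     # OIHW
assert w_2d.transpose((2, 3, 1, 0)).shape == (3, 3, 32, 64)         # HWIO

# 3D convs gain a depth axis: OIDHW -> DHWIO.
w_3d = np.zeros((64, 32, 2, 3, 3))                                  # OIDHW
assert w_3d.transpose((2, 3, 4, 1, 0)).shape == (2, 3, 3, 32, 64)   # DHWIO
```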
@@ -351,9 +360,14 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
     if stack_w_diff is None:
         stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32)
     wd = np.array(w["diff"])
-    # Reshape if 1x1 conv diff (squeeze spatial dims if needed, or broadcast)
-    if is_conv and wd.ndim != 5: wd = wd.reshape(1, 1, 1, in_feat, out_feat)
-    elif is_linear and wd.ndim != 2: wd = wd.reshape(in_feat, out_feat)
+    # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed.
+    if is_conv:
+        if wd.ndim == 5:
+            wd = wd.transpose((2, 3, 4, 1, 0))
+        elif wd.ndim == 4:
+            wd = wd.transpose((2, 3, 1, 0))
+    elif is_linear and wd.ndim == 2:
+        wd = wd.transpose((1, 0))
 
     stack_w_diff[i] = wd
     has_diff = True
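The linear case and the per-layer stacking can be sketched the same way. This assumes, as the indexing in the hunk implies, that a scanned kernel carries the layer index on its leading axis; shapes are made up:

```python
import numpy as np

# PyTorch nn.Linear stores weight as (out_features, in_features);
# the nnx.Linear kernel is (in_features, out_features), hence (1, 0).
wd = np.zeros((128, 768)).transpose((1, 0))          # (out, in) -> (in, out)
assert wd.shape == (768, 128)

# For a scanned module, each per-layer diff lands in its slot of the
# stacked buffer, mirroring `stack_w_diff[i] = wd` in the hunk.
n_layers = 4                                         # hypothetical layer count
stack_w_diff = np.zeros((n_layers, 768, 128), dtype=np.float32)
stack_w_diff[1] = wd                                 # layer index i = 1
```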