 
 
 @jax.jit
-def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff, m_diff):
+def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff):
     """
     Applies LoRA + Weight Diff + Bias Diff on device.
     """
@@ -48,17 +48,11 @@ def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff, m
     if bias is not None and b_diff is not None:
         bias = bias + b_diff.astype(bias.dtype)
 
-    # 4. Apply DoRA magnitude vector
-    if m_diff is not None:
-        kernel = kernel * m_diff.astype(kernel.dtype)
-
     return kernel, bias
 
 
 @jax.jit
-def _compute_and_add_scanned_jit(
-    kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None, m_diff=None
-):
+def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None):
     """
     Applies scanned LoRA + Diffs.
     """
@@ -80,14 +74,6 @@ def _compute_and_add_scanned_jit(
     if bias is not None and b_diffs is not None:
         bias = bias + b_diffs.astype(bias.dtype)
 
-    # 4. Apply DoRA magnitude vector
-    if m_diff is not None:
-        # Reshape for broadcasting with kernel
-        # kernel shape can be (L, In, Out) or (L, H, W, In, Out)
-        # m_diff shape is (L, Out)
-        new_shape = [m_diff.shape[0]] + [1] * (kernel.ndim - 2) + [m_diff.shape[1]]
-        kernel = kernel * m_diff.reshape(new_shape).astype(kernel.dtype)
-
     return kernel, bias
 
 
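For reference, the broadcast the removed scanned DoRA step performed: `m_diff` held one magnitude per layer and output channel, and the reshape inserted singleton axes so it scaled the last axis of either kernel layout. A minimal NumPy sketch with made-up shapes:

```python
import numpy as np

# Shapes follow the removed comments: kernel (L, H, W, In, Out), m_diff (L, Out).
L, H, W, In, Out = 2, 3, 3, 4, 8
kernel = np.ones((L, H, W, In, Out), dtype=np.float32)
m_diff = np.full((L, Out), 0.5, dtype=np.float32)

# [L] + [1] * (kernel.ndim - 2) + [Out] -> (L, 1, 1, 1, Out)
new_shape = [m_diff.shape[0]] + [1] * (kernel.ndim - 2) + [m_diff.shape[1]]
scaled = kernel * m_diff.reshape(new_shape)
assert scaled.shape == kernel.shape  # per-layer, per-output-channel scaling
```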
@@ -135,14 +121,6 @@ def parse_lora_dict(state_dict, dtype):
             lora_params[key_base]["diff"] = _to_jax_array(v, dtype=dtype)
             continue
 
-        # DoRA Magnitude (e.g., "layer.diff_m")
-        if k.endswith(".diff_m"):
-            key_base = k[: -len(".diff_m")]
-            if key_base not in lora_params:
-                lora_params[key_base] = {}
-            lora_params[key_base]["diff_m"] = _to_jax_array(v, dtype=dtype)
-            continue
-
         # Standard LoRA
         m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
         if not m:
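The surviving regex splits a checkpoint key into a base path plus a down/up tag; a quick illustration with a hypothetical key:

```python
import re

k = "blocks_0.attn.to_q.lora_down.weight"  # hypothetical checkpoint key
m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
assert m is not None
key_base, kind = m.group(1), m.group(2)
# key_base == "blocks_0.attn.to_q", kind == "lora_down"
```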
@@ -205,7 +183,6 @@ def _merge_lora_layer(module, weights, scale):
     # Prepare Diff terms
     w_diff = weights.get("diff", None)
     b_diff = weights.get("diff_b", None)
-    m_diff = weights.get("diff_m", None)
 
     if w_diff is not None:
         w_diff = np.array(w_diff)
@@ -219,8 +196,6 @@ def _merge_lora_layer(module, weights, scale):
             w_diff = w_diff.transpose((1, 0))
     if b_diff is not None:
         b_diff = np.array(b_diff)
-    if m_diff is not None:
-        m_diff = np.array(m_diff)
 
     # If LoCON, compute delta and add to w_diff
     if is_conv_kxk_locon:
@@ -247,9 +222,9 @@ def _merge_lora_layer(module, weights, scale):
     bias_val = module.bias.value if module.bias is not None else None
 
     # --- EXECUTE JIT UPDATE ---
-    if down_w is not None or w_diff is not None or b_diff is not None or m_diff is not None:
+    if down_w is not None or w_diff is not None or b_diff is not None:
         new_kernel, new_bias = _compute_and_add_single_jit(
-            module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff, m_diff
+            module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff
         )
 
         module.kernel.value = new_kernel
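A minimal sketch of what the single-layer helper now computes on the linear path, assuming the usual LoRA factor shapes (`down`: (rank, in), `up`: (out, rank)) and an (in, out) kernel layout; the real function also covers conv kernels:

```python
import jax

@jax.jit
def merge_single_sketch(kernel, bias, down, up, scale, w_diff, b_diff):
    if down is not None:
        # LoRA delta: (out, rank) @ (rank, in) -> (out, in), transposed to (in, out)
        kernel = kernel + ((up @ down).T * scale).astype(kernel.dtype)
    if w_diff is not None:
        kernel = kernel + w_diff.astype(kernel.dtype)
    if bias is not None and b_diff is not None:
        bias = bias + b_diff.astype(bias.dtype)
    return kernel, bias
```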
@@ -404,7 +379,6 @@ def merge_lora_for_scanned(
         # Initialize as None, allocate only if found to save memory
         stack_w_diff = None
         stack_b_diff = None
-        stack_m_diff = None
 
         has_lora = False
         has_diff = False
@@ -415,14 +389,6 @@ def merge_lora_for_scanned(
                 matched_keys.add(lora_key)
                 w = lora_params[lora_key]
 
-                # --- Fill DoRA Magnitude ---
-                if "m_diff" in w:
-                    if stack_m_diff is None:
-                        stack_m_diff = np.ones((num_layers, out_feat), dtype=np.float32)
-                    dm = np.array(w["m_diff"])
-                    stack_m_diff[i] = dm.flatten()
-                    has_diff = True
-
                 # --- Fill LoRA ---
                 if "down" in w:
                     d, u = np.array(w["down"]), np.array(w["up"])
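The lazy-allocation pattern used for these stacks (the removed DoRA branch followed it too): a buffer stays `None` until some layer actually carries a value, then one `(num_layers, ...)` array is allocated and filled by layer index. A small standalone sketch with made-up sizes:

```python
import numpy as np

num_layers, out_feat = 4, 16                         # illustrative sizes
per_layer = {1: np.full(out_feat, 0.5, np.float32)}  # hypothetical per-layer diffs

stack = None
for i in range(num_layers):
    if i in per_layer:
        if stack is None:
            # identity fill: ones for a multiplicative term, zeros for an additive one
            stack = np.ones((num_layers, out_feat), dtype=np.float32)
        stack[i] = per_layer[i].flatten()
```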
@@ -494,7 +460,6 @@ def merge_lora_for_scanned(
                 stack_w_diff,
                 stack_b_diff,
                 bias_val,
-                stack_m_diff,
             )
 
             module.kernel.value = new_k
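One note on the `None` defaults in the scanned signature: under `jax.jit`, `None` is an empty pytree, so an `x is not None` check resolves at trace time and absent diff terms compile away entirely. A tiny demonstration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_optional(x, delta=None):
    if delta is not None:  # static check: resolved when tracing
        x = x + delta
    return x

add_optional(jnp.ones(3))               # delta branch compiled out
add_optional(jnp.ones(3), jnp.ones(3))  # separate trace with delta present
```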