1616
1717from typing import Union , Tuple , Optional
1818import re
19+ import torch
20+ import torch .utils .dlpack
21+ from jax import dlpack
1922import jax .numpy as jnp
2023from flax import nnx
2124from .. import max_logging
@@ -200,6 +203,11 @@ def inject_lora(
200203
201204 return model
202205
206+ def _to_jax_array (v ):
207+ if isinstance (v , torch .Tensor ):
208+ return dlpack .from_dlpack (torch .utils .dlpack .to_dlpack (v ))
209+ return jnp .array (v )
210+
203211def merge_lora (model : nnx .Module , state_dict : dict , scale : float , translate_fn = None ):
204212 """
205213 Merges weights from a Diffusers-formatted state dict directly
@@ -213,7 +221,7 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
213221 key_base = k [:- len (".alpha" )]
214222 if key_base not in lora_params :
215223 lora_params [key_base ] = {}
216- lora_params [key_base ]["alpha" ] = jnp . array (v )
224+ lora_params [key_base ]["alpha" ] = _to_jax_array (v )
217225 continue
218226
219227 m = re .match (r"^(.*?)_lora\.(down|up)\.weight$" , k )
@@ -232,7 +240,7 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
232240 continue
233241 if key_base not in lora_params :
234242 lora_params [key_base ] = {}
235- lora_params [key_base ][weight_type ] = jnp . array (v )
243+ lora_params [key_base ][weight_type ] = _to_jax_array (v )
236244 max_logging .log (f"Parsed { len (lora_params )} unique LoRA module keys." )
237245
238246 assigned_count = 0
@@ -289,7 +297,7 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
289297 key_base = k [:- len (".alpha" )]
290298 if key_base not in lora_params :
291299 lora_params [key_base ] = {}
292- lora_params [key_base ]["alpha" ] = jnp . array (v )
300+ lora_params [key_base ]["alpha" ] = _to_jax_array (v )
293301 continue
294302
295303 m = re .match (r"^(.*?)_lora\.(down|up)\.weight$" , k )
@@ -308,7 +316,7 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
308316 continue
309317 if key_base not in lora_params :
310318 lora_params [key_base ] = {}
311- lora_params [key_base ][weight_type ] = jnp . array (v )
319+ lora_params [key_base ][weight_type ] = _to_jax_array (v )
312320 max_logging .log (f"Parsed { len (lora_params )} unique LoRA module keys for scanned merge." )
313321
314322 assigned_count = 0
0 commit comments