1919from .lora_pipeline import StableDiffusionLoraLoaderMixin
2020from ..models import lora_nnx
2121from .. import max_logging
22- from . import lora_conversion_utils , preprocess_wan_lora_dict
22+ from . import lora_conversion_utils
2323
2424
2525class Wan2_1NNXLoraLoader (LoRABaseMixin ):
@@ -53,7 +53,7 @@ def translate_fn(nnx_path_str):
5353 if hasattr (pipeline , "transformer" ) and transformer_weight_name :
5454 max_logging .log (f"Merging LoRA into transformer with rank={ rank } " )
5555 h_state_dict , _ = lora_loader .lora_state_dict (lora_model_path , weight_name = transformer_weight_name , ** kwargs )
56- h_state_dict = preprocess_wan_lora_dict (h_state_dict )
56+ h_state_dict = lora_conversion_utils . preprocess_wan_lora_dict (h_state_dict )
5757 merge_fn (pipeline .transformer , h_state_dict , rank , scale , translate_fn , dtype = dtype )
5858 else :
5959 max_logging .log ("transformer not found or no weight name provided for LoRA." )
@@ -94,7 +94,7 @@ def translate_fn(nnx_path_str: str):
9494 if hasattr (pipeline , "high_noise_transformer" ) and high_noise_weight_name :
9595 max_logging .log (f"Merging LoRA into high_noise_transformer with rank={ rank } " )
9696 h_state_dict , _ = lora_loader .lora_state_dict (lora_model_path , weight_name = high_noise_weight_name , ** kwargs )
97- h_state_dict = preprocess_wan_lora_dict (h_state_dict )
97+ h_state_dict = lora_conversion_utils . preprocess_wan_lora_dict (h_state_dict )
9898 merge_fn (pipeline .high_noise_transformer , h_state_dict , rank , scale , translate_fn , dtype = dtype )
9999 else :
100100 max_logging .log ("high_noise_transformer not found or no weight name provided for LoRA." )
@@ -103,7 +103,7 @@ def translate_fn(nnx_path_str: str):
103103 if hasattr (pipeline , "low_noise_transformer" ) and low_noise_weight_name :
104104 max_logging .log (f"Merging LoRA into low_noise_transformer with rank={ rank } " )
105105 l_state_dict , _ = lora_loader .lora_state_dict (lora_model_path , weight_name = low_noise_weight_name , ** kwargs )
106- l_state_dict = preprocess_wan_lora_dict (l_state_dict )
106+ l_state_dict = lora_conversion_utils . preprocess_wan_lora_dict (l_state_dict )
107107 merge_fn (pipeline .low_noise_transformer , l_state_dict , rank , scale , translate_fn , dtype = dtype )
108108 else :
109109 max_logging .log ("low_noise_transformer not found or no weight name provided for LoRA." )
0 commit comments