Skip to content

Commit 9e15a35

Browse files
committed
distilled lora connector key fix
1 parent 6603e0a commit 9e15a35

2 files changed

Lines changed: 14 additions & 1 deletion

File tree

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,9 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
786786
"caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
787787
"audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
788788
"audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
789+
790+
# Connectors
791+
"feature_extractor.linear": "text_embedding_projection.aggregate_embed",
789792
}
790793

791794
if nnx_path_str in global_map:

src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,22 @@ def load_lora_weights(
5050
def translate_fn(nnx_path_str):
5151
return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
5252

53+
h_state_dict = None
5354
if hasattr(pipeline, "transformer") and transformer_weight_name:
5455
max_logging.log(f"Merging LoRA into transformer with rank={rank}")
5556
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
56-
# We assume keys match the translation function output.
5757
merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
5858
else:
5959
max_logging.log("transformer not found or no weight name provided for LoRA.")
6060

61+
if hasattr(pipeline, "connectors"):
62+
max_logging.log(f"Merging LoRA into connectors with rank={rank}")
63+
if h_state_dict is None and transformer_weight_name:
64+
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
65+
66+
if h_state_dict is not None:
67+
merge_fn(pipeline.connectors, h_state_dict, rank, scale, translate_fn, dtype=dtype)
68+
else:
69+
max_logging.log("Could not load LoRA state dict for connectors.")
70+
6171
return pipeline

Comments (0) — no commit comments.