Skip to content

Commit dbc0e10

Browse files
committed
logs parsing made better
1 parent 9e15a35 commit dbc0e10

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def translate_fn(nnx_path_str):
5454
if hasattr(pipeline, "transformer") and transformer_weight_name:
5555
max_logging.log(f"Merging LoRA into transformer with rank={rank}")
5656
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
57-
merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
57+
# Filter state dict for transformer keys to avoid confusing warnings
58+
transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")}
59+
merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
5860
else:
5961
max_logging.log("transformer not found or no weight name provided for LoRA.")
6062

@@ -64,7 +66,9 @@ def translate_fn(nnx_path_str):
6466
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
6567

6668
if h_state_dict is not None:
67-
merge_fn(pipeline.connectors, h_state_dict, rank, scale, translate_fn, dtype=dtype)
69+
# Filter state dict for connector keys to avoid confusing warnings
70+
connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}
71+
merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
6872
else:
6973
max_logging.log("Could not load LoRA state dict for connectors.")
7074

0 commit comments

Comments
 (0)