Skip to content

Commit 0577bd1

Browse files
committed
LTX-2 LoRA
1 parent 3f4cfc3 commit 0577bd1

4 files changed

Lines changed: 166 additions & 0 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,13 @@ jit_initializers: True
103103
enable_single_replica_ckpt_restoring: False
seed: 0
audio_format: "s16"

# LoRA parameters
# The lists inside lora_config are parallel: index i of every list
# describes the same adapter (repo path, weights file, adapter name, rank).
enable_lora: False
lora_config: {
  lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
  weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
  adapter_name: ["camera-control-dolly-in"],
  rank: [32]
}
115+

src/maxdiffusion/generate_ltx2.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.api_core.exceptions import GoogleAPIError
2626
import flax
2727
from maxdiffusion.utils.export_utils import export_to_video_with_audio
28+
from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader
2829

2930

3031
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -118,6 +119,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
118119
checkpoint_loader = LTX2Checkpointer(config=config)
119120
pipeline, _, _ = checkpoint_loader.load_checkpoint()
120121

122+
# If LoRA is specified, inject layers and load weights.
123+
if (
124+
getattr(config, "enable_lora", False)
125+
and hasattr(config, "lora_config")
126+
and config.lora_config
127+
and config.lora_config.get("lora_model_name_or_path")
128+
):
129+
lora_loader = LTX2NNXLoraLoader()
130+
lora_config = config.lora_config
131+
paths = lora_config["lora_model_name_or_path"]
132+
weights = lora_config.get("weight_name", [None] * len(paths))
133+
scales = lora_config.get("scale", [1.0] * len(paths))
134+
ranks = lora_config.get("rank", [64] * len(paths))
135+
136+
for i in range(len(paths)):
137+
pipeline = lora_loader.load_lora_weights(
138+
pipeline,
139+
paths[i],
140+
transformer_weight_name=weights[i],
141+
rank=ranks[i],
142+
scale=scales[i],
143+
scan_layers=config.scan_layers,
144+
dtype=config.weights_dtype,
145+
)
146+
121147
pipeline.enable_vae_slicing()
122148
pipeline.enable_vae_tiling()
123149

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,72 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
703703
return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
704704

705705
return None
706+
707+
708+
def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """Translates an LTX2 NNX module path into its Diffusers LoRA key.

  Args:
    nnx_path_str: Dot-joined NNX path, e.g. "blocks.3.attn1.to_q" when
      layers are unscanned, or "blocks.attn1.to_q" when scanned (the layer
      index is elided from scanned paths).
    scan_layers: If True, the returned key contains a literal "{}"
      placeholder where the caller substitutes the layer index.

  Returns:
    The matching "diffusion_model.transformer_blocks.<idx>.<module>" LoRA
    key, or None when the path does not name a LoRA-targeted module.
  """
  # --- 1. Map NNX suffixes to LoRA suffixes ---
  # All six attention modules share the same q/k/v naming; only the output
  # projection differs ("to_out" -> "to_out.0", since Diffusers keeps it in
  # a ModuleList). Generate the entries instead of hand-writing 24 of them.
  attention_prefixes = (
      "attn1",            # self attention
      "attn2",            # cross attention
      "audio_attn1",      # audio self attention
      "audio_attn2",      # audio cross attention
      "audio_to_video_attn",
      "video_to_audio_attn",
  )
  suffix_map = {}
  for prefix in attention_prefixes:
    for proj in ("to_q", "to_k", "to_v"):
      suffix_map[f"{prefix}.{proj}"] = f"{prefix}.{proj}"
    suffix_map[f"{prefix}.to_out"] = f"{prefix}.to_out.0"
  # Feed-forward: NNX flattens the layers; Diffusers keeps nested "net"
  # indices (net.0 is the activation-in layer whose linear is "proj").
  for prefix in ("ff", "audio_ff"):
    suffix_map[f"{prefix}.net_0"] = f"{prefix}.net.0.proj"
    suffix_map[f"{prefix}.net_2"] = f"{prefix}.net.2"

  # --- 2. Translation logic ---
  if scan_layers:
    # Scanned paths carry no per-layer index, so emit a "{}" placeholder.
    if nnx_path_str.startswith("blocks."):
      inner_suffix = nnx_path_str[len("blocks.") :]
      if inner_suffix in suffix_map:
        return f"diffusion_model.transformer_blocks.{{}}.{suffix_map[inner_suffix]}"
  else:
    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
    if m:
      idx, inner_suffix = m.group(1), m.group(2)
      if inner_suffix in suffix_map:
        return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}"

  return None
773+
774+
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""NNX-based LoRA loader for LTX2 models."""
16+
17+
from flax import nnx
18+
from .lora_base import LoRABaseMixin
19+
from .lora_pipeline import StableDiffusionLoraLoaderMixin
20+
from ..models import lora_nnx
21+
from .. import max_logging
22+
from . import lora_conversion_utils
23+
24+
25+
class LTX2NNXLoraLoader(LoRABaseMixin):
  """Loads LoRA weights into an NNX-based LTX2 model.

  Assumes the LTX2 pipeline exposes a 'transformer' attribute that is an
  NNX Module.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      dtype: str = "float32",
      **kwargs,
  ):
    """Merges the LoRA checkpoint at ``lora_model_path`` into the pipeline.

    Returns the pipeline with LoRA weights merged into its transformer.
    """
    # Nothing to merge into: no transformer on the pipeline, or no weights
    # file named — log and hand the pipeline back untouched.
    if not hasattr(pipeline, "transformer") or not transformer_weight_name:
      max_logging.log("transformer not found or no weight name provided for LoRA.")
      return pipeline

    max_logging.log(f"Merging LoRA into transformer with rank={rank}")

    # Reuse the Diffusers-style mixin purely to fetch/parse the state dict.
    sd_lora_mixin = StableDiffusionLoraLoaderMixin()
    state_dict, _ = sd_lora_mixin.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)

    def translate_fn(nnx_path_str):
      # Checkpoint keys are assumed to match this translation's output.
      return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    # Scanned transformers need the scan-aware merge variant.
    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
    merge_fn(pipeline.transformer, state_dict, rank, scale, translate_fn, dtype=dtype)

    return pipeline

0 commit comments

Comments
 (0)