
Commit b9b4392

committed
refactor for weight loading
1 parent 9f2d778 commit b9b4392

3 files changed

Lines changed: 402 additions & 178 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 305 additions & 3 deletions
@@ -4,13 +4,88 @@
 from flax import nnx
 from flax.traverse_util import unflatten_dict, flatten_dict
 from maxdiffusion import max_logging
-from ..modeling_flax_pytorch_utils import validate_flax_state_dict
+from ..modeling_flax_pytorch_utils import validate_flax_state_dict, rename_key
 from .ltx2_utils import load_sharded_checkpoint
 from .ltx2_utils import (
     _tuple_str_to_int,
     LTX_2_0_VIDEO_VAE_RENAME_DICT,
+    rename_for_ltx2_transformer,
+    get_key_and_value,
+    rename_for_ltx2_audio_vae,
+    rename_for_ltx2_vocoder,
 )
+
+
+def load_ltx2_3_checkpoint(pretrained_model_name_or_path: str, subfolder: str, device: str, filename: str):
+  """Loads weights from a single safetensors file for LTX-2.3."""
+  from huggingface_hub import hf_hub_download
+  from safetensors import safe_open
+  from ..modeling_flax_pytorch_utils import torch2jax
 
+  ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
+  tensors = {}
+  with safe_open(ckpt_path, framework="pt") as f:
+    for k in f.keys():
+      tensors[k] = torch2jax(f.get_tensor(k))
+  return tensors
+
+
+def rename_for_ltx2_3_transformer(key):
+  """
+  Renames Diffusers LTX-2.3 keys to MaxDiffusion Flax LTX-2.3 keys.
+  """
+  key = key.replace("patchify_proj", "proj_in")
+  key = key.replace("audio_patchify_proj", "audio_proj_in")
+  key = key.replace("norm_final", "norm_out")
+  if "adaLN_modulation_1" in key:
+    key = key.replace("adaLN_modulation_1", "scale_shift_table")
+
+  if "caption_modulator_1" in key:
+    key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
+  if "audio_caption_modulator_1" in key:
+    key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
+  if "audio_norm_final" in key:
+    key = key.replace("audio_norm_final", "audio_norm_out")
+  if ("audio_ff" in key or "ff" in key) and "proj" in key:
+    key = key.replace(".proj", "")
+  if "to_out_0" in key:
+    key = key.replace("to_out_0", "to_out")
+
+  # Add missing mappings
+  key = key.replace("av_ca_video_scale_shift_adaln_single", "av_cross_attn_video_scale_shift")
+  key = key.replace("av_ca_a2v_gate_adaln_single", "av_cross_attn_video_a2v_gate")
+  key = key.replace("av_ca_audio_scale_shift_adaln_single", "av_cross_attn_audio_scale_shift")
+  key = key.replace("av_ca_v2a_gate_adaln_single", "av_cross_attn_audio_v2a_gate")
+  key = key.replace("scale_shift_table_a2v_ca_video", "video_a2v_cross_attn_scale_shift_table")
+  key = key.replace("scale_shift_table_a2v_ca_audio", "audio_a2v_cross_attn_scale_shift_table")
+
+  # LTX-2.3 specific mappings
+  # Handle substrings before they are replaced by shorter patterns below
+  key = key.replace("audio_prompt_adaln_single", "audio_prompt_adaln")
+  key = key.replace("prompt_adaln_single", "prompt_adaln")
+  key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
+  key = key.replace("prompt_scale_shift_table", "scale_shift_table")
+
+  if "prompt_adaln" in key:
+    key = key.replace("prompt_adaln", "caption_projection")
+  if "audio_prompt_adaln" in key:
+    key = key.replace("audio_prompt_adaln", "audio_caption_projection")
+  if "video_text_proj_in" in key:
+    key = key.replace("video_text_proj_in", "feature_extractor.video_linear")
+  if "audio_text_proj_in" in key:
+    key = key.replace("audio_text_proj_in", "feature_extractor.audio_linear")
+
+  key = key.replace("k_norm", "norm_k")
+  key = key.replace("q_norm", "norm_q")
+  key = key.replace("adaln_single", "time_embed")
+  return key
+
+
+def rename_for_ltx2_3_vocoder(key):
+  """Renames Diffusers LTX-2.3 Vocoder keys to MaxDiffusion Flax keys."""
+  key = key.replace("ups.", "upsamplers.")
+  key = key.replace("resblocks.", "resblocks_")
+  key = key.replace("conv_post", "conv_out")
+  key = key.replace("conv_pre", "conv_in")
+  key = key.replace("act_post", "act_out")
+
+  # LTX-2.3 specific mappings for Vocoder
+  if "downsample" in key and "lowpass" not in key:
+    key = key.replace("downsample", "downsample.lowpass")
+
+  return key
 
 
 LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = {
@@ -46,19 +121,246 @@
     "audio_embeddings_connector": "audio_connector",
 }
 
-def load_connectors_weights(
+def load_and_segregate_ltx2_3_weights(pretrained_model_name_or_path: str, filename: str = "ltx-2.3-22b-dev.safetensors"):
+  """Loads the full LTX-2.3 file once and splits it into component-specific dictionaries."""
+  tensors = load_ltx2_3_checkpoint(pretrained_model_name_or_path, "", "cpu", filename=filename)
+
+  segregated = {
+      "transformer": {},
+      "vae": {},
+      "audio_vae": {},
+      "connectors": {},
+      "vocoder": {},
+  }
+
+  for pt_key, tensor in tensors.items():
+    if pt_key.startswith("model.diffusion_model."):
+      segregated["transformer"][pt_key.replace("model.diffusion_model.", "")] = tensor
+    elif pt_key.startswith("audio_vae."):
+      segregated["audio_vae"][pt_key.replace("audio_vae.", "")] = tensor
+    elif pt_key.startswith("vae."):
+      segregated["vae"][pt_key] = tensor
+    elif pt_key.startswith("vocoder."):
+      segregated["vocoder"][pt_key.replace("vocoder.", "")] = tensor
+    elif any(x in pt_key for x in ["connectors.", "video_embeddings_connector", "audio_embeddings_connector", "text_embedding_projection"]):
+      segregated["connectors"][pt_key] = tensor
+
+  return segregated
+
+
+def load_transformer_weights_2_3(
+    eval_shapes: dict,
+    device: str,
+    tensors: dict,
+    num_layers: int = 48,
+    scan_layers: bool = True,
+):
+  device = jax.local_devices(backend=device)[0]
+  max_logging.log(f"Load and port LTX-2.3 transformer on {device}")
+
+  with jax.default_device(device):
+    flax_state_dict = {}
+    cpu = jax.local_devices(backend="cpu")[0]
+    flattened_dict = flatten_dict(eval_shapes)
+
+    random_flax_state_dict = {}
+    for key in flattened_dict:
+      random_flax_state_dict[tuple(str(item) for item in key)] = flattened_dict[key]
+
+    for pt_key, tensor in tensors.items():
+      # Keys are already filtered and stripped of "model.diffusion_model." by load_and_segregate
+      if pt_key.startswith("audio_embeddings_connector") or pt_key.startswith("video_embeddings_connector"):
+        continue
+
+      renamed_pt_key = rename_key(pt_key)
+      renamed_pt_key = rename_for_ltx2_3_transformer(renamed_pt_key)
+
+      pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+      flax_key, flax_tensor = get_key_and_value(
+          pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
+      )
+
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
+  validate_flax_state_dict(eval_shapes, flax_state_dict)
+  flax_state_dict = unflatten_dict(flax_state_dict)
+  jax.clear_caches()
+  return flax_state_dict
+
+
+def load_audio_vae_weights_2_3(
+    eval_shapes: dict,
+    device: str,
+    tensors: dict,
+):
+  flax_state_dict = {}
+  cpu = jax.local_devices(backend="cpu")[0]
+
+  flattened_eval = flatten_dict(eval_shapes)
+
+  for pt_key, tensor in tensors.items():
+    # Keys are already filtered and stripped of "audio_vae." by load_and_segregate
+    key = rename_for_ltx2_audio_vae(pt_key)
+
+    if key.endswith(".kernel") and tensor.ndim == 4:
+      tensor = tensor.transpose(2, 3, 1, 0)
+
+    flax_key = _tuple_str_to_int(key.split("."))
+
+    if "up_stages" in flax_key:
+      up_stages_idx = flax_key.index("up_stages")
+      if up_stages_idx + 1 < len(flax_key) and isinstance(flax_key[up_stages_idx + 1], int):
+        flax_key_list = list(flax_key)
+        flax_key_list[up_stages_idx + 1] = 2 - flax_key[up_stages_idx + 1]
+        flax_key = tuple(flax_key_list)
+
+    flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
+
+  filtered_eval_shapes = {
+      k: v for k, v in flattened_eval.items() if not any("dropout" in str(x) or "rngs" in str(x) for x in k)
+  }
+
+  validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
+  return unflatten_dict(flax_state_dict)
+
+
+def load_vae_weights_2_3(
+    eval_shapes: dict,
+    device: str,
+    tensors: dict,
+):
+  flax_state_dict = {}
+  cpu = jax.local_devices(backend="cpu")[0]
+  flattened_eval = flatten_dict(eval_shapes)
+
+  random_flax_state_dict = {}
+  for key in flattened_eval:
+    random_flax_state_dict[tuple(str(item) for item in key)] = flattened_eval[key]
+
+  for pt_key, tensor in tensors.items():
+    # Remove 'vae.' prefix if present in safetensors but not in model
+    if pt_key.startswith("vae."):
+      pt_key = pt_key[len("vae."):]
+
+    renamed_pt_key = pt_key.replace("nin_shortcut", "conv_shortcut")
+
+    pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+    pt_list = []
+    resnet_index = None
+
+    for i, part in enumerate(pt_tuple_key):
+      if "_" in part and part.split("_")[-1].isdigit():
+        name = "_".join(part.split("_")[:-1])
+        idx = int(part.split("_")[-1])
+
+        if name == "resnets" or name == "block":
+          pt_list.append("resnets")
+          resnet_index = idx
+        elif name == "upsamplers":
+          pt_list.append("upsampler")
+        elif name in ["down_blocks", "up_blocks", "downsamplers"]:
+          pt_list.append(name)
+          pt_list.append(str(idx))
+        else:
+          pt_list.append(part)
+      elif part == "upsampler":
+        pt_list.append("upsampler")
+      elif part in ["conv1", "conv2", "conv", "conv_in", "conv_out", "conv_shortcut"]:
+        pt_list.append(part)
+        if (
+            part != "conv"
+            and (i + 1 == len(pt_tuple_key) or pt_tuple_key[i + 1] != "conv")
+            and (len(pt_list) < 2 or pt_list[-2] != "conv")
+        ):
+          pt_list.append("conv")
+      else:
+        pt_list.append(part)
+
+    pt_tuple_key = tuple(pt_list)
+
+    from .ltx2_utils import rename_key_and_reshape_tensor
+    flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
+
+    flax_key = _tuple_str_to_int(flax_key)
+
+    if resnet_index is not None:
+      str_flax_key = tuple([str(x) for x in flax_key])
+      if str_flax_key in random_flax_state_dict:
+        if flax_key not in flax_state_dict:
+          target_shape = random_flax_state_dict[str_flax_key].shape
+          flax_state_dict[flax_key] = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
+        flax_state_dict[flax_key] = flax_state_dict[flax_key].at[resnet_index].set(flax_tensor)
+      else:
+        flax_state_dict[flax_key] = flax_tensor
+    else:
+      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
+  filtered_eval_shapes = {
+      k: v for k, v in flattened_eval.items() if not any("dropout" in str(x) or "rngs" in str(x) for x in k)
+  }
+
+  validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
+  return unflatten_dict(flax_state_dict)
+
+
+def load_vocoder_weights_2_3(
+    eval_shapes: dict,
+    device: str,
+    tensors: dict,
+):
+  flax_state_dict = {}
+  cpu = jax.local_devices(backend="cpu")[0]
+
+  for pt_key, tensor in tensors.items():
+    # Keys are already filtered and stripped of "vocoder." by load_and_segregate
+    key = rename_for_ltx2_3_vocoder(pt_key)
+
+    # Always apply LTX-2.3 specific replacement
+    key = key.replace("resblocks_", "resnets.")
+
+    parts = key.split(".")
+
+    if parts[-1] == "weight":
+      parts[-1] = "kernel"
+
+    flax_key = _tuple_str_to_int(parts)
+
+    # Skip filter keys as they are derived in the NNX model
+    if "filter" in flax_key:
+      continue
+
+    if flax_key[-1] == "kernel":
+      if "upsamplers" in flax_key:
+        tensor = tensor.transpose(2, 0, 1)[::-1, :, :]
+      else:
+        tensor = tensor.transpose(2, 1, 0)
+
+    if "mel_stft" in flax_key and ("forward_basis" in flax_key or "inverse_basis" in flax_key):
+      tensor = tensor.transpose(2, 1, 0)
+
+    flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
+
+  validate_flax_state_dict(eval_shapes, flax_state_dict)
+  return unflatten_dict(flax_state_dict)
+
+
+def load_connectors_weights_2_3(
     pretrained_model_name_or_path: str,
     eval_shapes: dict,
     device: str,
     hf_download: bool = True,
     subfolder: str = "",
     filename: str = None,
     is_ltx2_3: bool = False,
+    tensors: dict = None,
 ):
   device = jax.local_devices(backend=device)[0]
 
   with jax.default_device(device):
-    tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device, filename=filename)
+    if tensors is None:
+      tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device, filename=filename)
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
     flattened_eval = flatten_dict(eval_shapes)
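The flow after this refactor appears to be: read the single ltx-2.3-22b-dev.safetensors file once via load_and_segregate_ltx2_3_weights, then hand each component loader its slice. A minimal wiring sketch; the repo id and the *_eval_shapes variables are placeholders supplied by the caller (abstract parameter shapes of each Flax module):

from maxdiffusion.models.ltx2.ltx2_3_utils import (
    load_and_segregate_ltx2_3_weights,
    load_transformer_weights_2_3,
    load_vae_weights_2_3,
    load_audio_vae_weights_2_3,
    load_vocoder_weights_2_3,
)

# One read of the checkpoint, split by key prefix into
# "transformer" / "vae" / "audio_vae" / "connectors" / "vocoder" dicts.
weights = load_and_segregate_ltx2_3_weights("Lightricks/LTX-2")  # placeholder repo id

transformer_state = load_transformer_weights_2_3(
    eval_shapes=transformer_eval_shapes,  # placeholder: abstract shapes for the transformer
    device="cpu",
    tensors=weights["transformer"],
)
vae_state = load_vae_weights_2_3(vae_eval_shapes, "cpu", weights["vae"])
audio_vae_state = load_audio_vae_weights_2_3(audio_vae_eval_shapes, "cpu", weights["audio_vae"])
vocoder_state = load_vocoder_weights_2_3(vocoder_eval_shapes, "cpu", weights["vocoder"])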
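Finally, load_connectors_weights was renamed to load_connectors_weights_2_3 and gained an optional tensors argument, so the already-segregated dict can be reused instead of re-reading the checkpoint; when tensors is None it still falls back to load_sharded_checkpoint. Continuing the sketch above:

connectors_state = load_connectors_weights_2_3(
    pretrained_model_name_or_path="Lightricks/LTX-2",  # placeholder repo id
    eval_shapes=connectors_eval_shapes,  # placeholder abstract shapes
    device="cpu",
    is_ltx2_3=True,
    tensors=weights["connectors"],  # reuse segregated tensors; None -> load_sharded_checkpoint
)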