Skip to content

Commit 57b61de

Browse files
committed
Refactor: move LTX 2.3 weights loading into a separate file
1 parent 23bb98e commit 57b61de

3 files changed

Lines changed: 73 additions & 8 deletions

File tree

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import numpy as np
4+
from flax import nnx
5+
from flax.traverse_util import unflatten_dict, flatten_dict
6+
from maxdiffusion.utils import max_logging
7+
from maxdiffusion.models.flax_pytorch_utils import (
8+
load_sharded_checkpoint,
9+
validate_flax_state_dict,
10+
)
11+
from maxdiffusion.models.ltx2.ltx2_utils import (
12+
_tuple_str_to_int,
13+
LTX_2_0_VIDEO_VAE_RENAME_DICT,
14+
)
15+
16+
# Key renames for the LTX 2.3 video VAE: every LTX 2.0 rename, followed by the
# entries for the extra decoder blocks. Insertion order is kept identical to a
# `{**base, ...}` literal (base entries first, decoder entries last).
#
# NOTE(review): this table is built from LTX_2_0_VIDEO_VAE_RENAME_DICT, which
# this module imports from ltx2_utils, while ltx2_utils imports
# LTX_2_3_VIDEO_VAE_RENAME_DICT from this module in its top-level import
# block. That is a circular import between the two modules -- confirm it
# resolves at import time, or keep this table next to the 2.0 one.
LTX_2_3_VIDEO_VAE_RENAME_DICT = dict(LTX_2_0_VIDEO_VAE_RENAME_DICT)
LTX_2_3_VIDEO_VAE_RENAME_DICT.update(
    {
        # Decoder extra blocks
        "up_blocks.7": "up_blocks.3.upsamplers.0",
        "up_blocks.8": "up_blocks.3",
    }
)
22+
23+
# Substring renames applied to connector checkpoint keys. The pairs are listed
# in application order: load_connectors_weights walks this mapping and calls
# str.replace once per entry, so the order below is part of the contract.
LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = dict(
    (
        # Drop the "connectors." prefix.
        ("connectors.", ""),
        # Connector module names.
        ("video_embeddings_connector", "video_connector"),
        ("audio_embeddings_connector", "audio_connector"),
        ("transformer_1d_blocks", "transformer_blocks"),
        # Text-embedding projections.
        ("text_embedding_projection.audio_aggregate_embed", "audio_text_proj_in"),
        ("text_embedding_projection.video_aggregate_embed", "video_text_proj_in"),
        # Query/key norm naming.
        ("q_norm", "norm_q"),
        ("k_norm", "norm_k"),
    )
)
33+
34+
def load_connectors_weights(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    subfolder: str = "",
    filename: str = None,
):
  """Load the connector tensors of a checkpoint and return them as a nested dict.

  Checkpoint entries whose key mentions a connector are renamed with
  LTX_2_3_CONNECTORS_KEYS_RENAME_DICT, split on "." into a tuple path, copied
  to the first local CPU device, validated against ``eval_shapes`` and
  returned in nested (unflattened) form.

  Args:
    pretrained_model_name_or_path: Forwarded to ``load_sharded_checkpoint``.
    eval_shapes: Nested dict describing the expected parameter tree. It is only
      used for validation; entries whose path contains "dropout" or "rngs" are
      left out of that comparison.
    device: Backend name handed to ``jax.local_devices(backend=...)``; the
      first device of that backend is used as the default device while loading.
    hf_download: Accepted for signature compatibility; not read by this
      function.
    subfolder: Forwarded to ``load_sharded_checkpoint``.
    filename: Forwarded to ``load_sharded_checkpoint`` as ``filename=``.

  Returns:
    Nested dict of the renamed connector tensors.
  """
  # NOTE(review): hf_download is never used below -- confirm whether it should
  # be forwarded to the checkpoint loader.
  device = jax.local_devices(backend=device)[0]
  max_logging.log(f"Load and port {pretrained_model_name_or_path} Connectors on {device}")

  # A checkpoint key is treated as a connector weight if it contains any of these.
  connector_markers = ("connectors.", "video_embeddings_connector", "audio_embeddings_connector")

  with jax.default_device(device):
    # NOTE(review): this module imports load_sharded_checkpoint from
    # maxdiffusion.models.flax_pytorch_utils, whereas ltx2_utils imports its
    # helpers from ..modeling_flax_pytorch_utils -- confirm the module path.
    checkpoint = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device, filename=filename)
    expected_flat = flatten_dict(eval_shapes)
    host_cpu = jax.local_devices(backend="cpu")[0]
    flax_state_dict = {}

    for pt_key, tensor in checkpoint.items():
      if any(marker in pt_key for marker in connector_markers):
        # Apply every substring rename, in the mapping's own order.
        renamed = pt_key
        for old, new in LTX_2_3_CONNECTORS_KEYS_RENAME_DICT.items():
          renamed = renamed.replace(old, new)

        # NOTE(review): tensors are stored as loaded; no reshape/transpose or
        # "weight" -> "kernel" rename happens here -- confirm the checkpoint
        # layout already matches eval_shapes.
        path = _tuple_str_to_int(renamed.split("."))
        flax_state_dict[path] = jax.device_put(tensor, device=cpu) if False else jax.device_put(tensor, device=host_cpu)

    # Dropout / RNG bookkeeping entries have no checkpoint counterpart, so they
    # are excluded from the expected tree before validation.
    expected_params = {}
    for path, leaf in expected_flat.items():
      if all("dropout" not in str(part) and "rngs" not in str(part) for part in path):
        expected_params[path] = leaf

    validate_flax_state_dict(unflatten_dict(expected_params), flax_state_dict)
    return unflatten_dict(flax_state_dict)

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from safetensors import safe_open
2525
from flax.traverse_util import unflatten_dict, flatten_dict
2626
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
27+
from .ltx2_3_utils import LTX_2_3_VIDEO_VAE_RENAME_DICT
2728

2829

2930
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
@@ -54,12 +55,7 @@
5455
"per_channel_statistics.std-of-means": "latents_std",
5556
}
5657

57-
LTX_2_3_VIDEO_VAE_RENAME_DICT = {
58-
**LTX_2_0_VIDEO_VAE_RENAME_DICT,
59-
# Decoder extra blocks
60-
"up_blocks.7": "up_blocks.3.upsamplers.0",
61-
"up_blocks.8": "up_blocks.3",
62-
}
58+
6359

6460

6561
def _tuple_str_to_int(in_tuple):
@@ -532,3 +528,6 @@ def load_audio_vae_weights(
532528

533529
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
534530
return unflatten_dict(flax_state_dict)
531+
532+
533+

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder
3737
from ...models.ltx2.vocoder_bwe_ltx2 import LTX2VocoderWithBWE
3838
from ...models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
39+
from ...models.ltx2.ltx2_3_utils import load_connectors_weights
3940
from ...models.ltx2.ltx2_utils import (
4041
load_transformer_weights,
41-
load_connector_weights,
4242
load_vae_weights,
4343
load_audio_vae_weights,
4444
load_vocoder_weights,
@@ -352,7 +352,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
352352
params = state.to_pure_dict()
353353
state = dict(nnx.to_flat_state(state))
354354

355-
params = load_connector_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="connectors")
355+
params = load_connectors_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="")
356356
if hasattr(config, "weights_dtype"):
357357
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
358358

0 commit comments

Comments
 (0)