Skip to content

Commit 5d0e4a5

Browse files
committed
weight loading for vocoder
1 parent d485921 commit 5d0e4a5

2 files changed

Lines changed: 118 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,95 @@ def load_ltx2_vae(
406406
validate_flax_state_dict(eval_shapes, flax_state_dict)
407407
flax_state_dict = unflatten_dict(flax_state_dict)
408408
return flax_state_dict
409+
410+
411+
# Substring renames applied, in this order, to every PyTorch checkpoint key
# before it is split on "." and handed to rename_key_and_reshape_tensor.
_LTX2_VOCODER_KEY_RENAMES = (
    # ups.X.* -> upsamplers.layers.X.*
    ("ups.", "upsamplers.layers."),
    # resblocks.X.convs{1,2}.Y.* -> resnets.layers.X.convs{1,2}.layers.Y.*
    ("resblocks.", "resnets.layers."),
    ("convs1.", "convs1.layers."),
    ("convs2.", "convs2.layers."),
)


def _resolve_ltx2_vocoder_checkpoint(pretrained_model_name_or_path, subfolder, filename, hf_download):
  """Return the path of the vocoder checkpoint file, downloading it if allowed.

  Resolution order:
    1. `<pretrained_model_name_or_path>/<subfolder>/<filename>` when the first
       argument is a local directory containing that file.
    2. `hf_hub_download(...)` when `hf_download` is true.
    3. `pretrained_model_name_or_path` itself when it is an existing file.

  Raises:
    FileNotFoundError: if none of the above yields a checkpoint.
  """
  if os.path.isdir(pretrained_model_name_or_path):
    local_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
    if os.path.isfile(local_path):
      return local_path

  if hf_download:
    # Download errors propagate unchanged to the caller.
    return hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)

  if os.path.isfile(pretrained_model_name_or_path):
    return pretrained_model_name_or_path

  raise FileNotFoundError(
      f"Could not find LTX2 vocoder weights for {pretrained_model_name_or_path}: expected "
      f"{os.path.join(subfolder, filename)} inside a local directory, or a direct path to a "
      "checkpoint file. Pass hf_download=True to fetch the weights from the Hugging Face Hub."
  )


def load_ltx2_vocoder(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    subfolder: str = "vocoder",
):
  """Load LTX2 vocoder weights from a safetensors checkpoint into a nested Flax state dict.

  Args:
    pretrained_model_name_or_path: Local directory, direct checkpoint file path,
      or Hugging Face Hub repo id.
    eval_shapes: Nested dict describing the target Flax parameters; it is
      flattened to build the key lookup and passed to `validate_flax_state_dict`.
    device: JAX backend name handed to `jax.local_devices(backend=...)`.
    hf_download: Whether the checkpoint may be downloaded from the Hub when it
      is not found in a local directory.
    subfolder: Sub-directory (local or in the repo) that holds the checkpoint.

  Returns:
    The unflattened Flax state dict, with every tensor placed on the first
    local CPU device.

  Raises:
    FileNotFoundError: if no checkpoint can be located and `hf_download` is false.
  """
  # NOTE(review): the resolved device is not used below (tensors are placed on
  # CPU); the lookup is kept so an unavailable backend still fails here.
  device = jax.local_devices(backend=device)[0]
  # Vocoder weights are expected in diffusion_pytorch_model.safetensors inside `subfolder`.
  filename = "diffusion_pytorch_model.safetensors"

  ckpt_path = _resolve_ltx2_vocoder_checkpoint(pretrained_model_name_or_path, subfolder, filename, hf_download)
  max_logging.log(f"Load and port {pretrained_model_name_or_path} Vocoder from {ckpt_path}")

  with safe_open(ckpt_path, framework="pt") as f:
    tensors = {k: torch2jax(f.get_tensor(k)) for k in f.keys()}

  cpu = jax.local_devices(backend="cpu")[0]

  # Flatten eval_shapes and stringify every key component so the lookup keys
  # line up with the "."-split PyTorch keys built below.
  random_flax_state_dict = {
      tuple(str(item) for item in key): shape for key, shape in flatten_dict(eval_shapes).items()
  }

  flax_state_dict = {}
  for pt_key, tensor in tensors.items():
    # Map PyTorch module names onto the Flax LTX2Vocoder layout. conv_in and
    # conv_out keep their names; presumably the .weight -> .kernel conversion
    # is handled by rename_key_and_reshape_tensor — TODO confirm.
    renamed_pt_key = pt_key
    for old, new in _LTX2_VOCODER_KEY_RENAMES:
      renamed_pt_key = renamed_pt_key.replace(old, new)

    pt_tuple_key = tuple(renamed_pt_key.split("."))

    flax_key, flax_tensor = rename_key_and_reshape_tensor(
        pt_tuple_key, tensor, random_flax_state_dict, scan_layers=False
    )
    flax_key = _tuple_str_to_int(flax_key)

    flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

  validate_flax_state_dict(eval_shapes, flax_state_dict)
  flax_state_dict = unflatten_dict(flax_state_dict)
  return flax_state_dict
500+

src/maxdiffusion/tests/test_loading_ltx2.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from flax import nnx
77
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
88
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
9-
from maxdiffusion.models.ltx2.ltx2_utils import load_ltx2_transformer, load_ltx2_vae
9+
from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder
10+
from maxdiffusion.models.ltx2.ltx2_utils import load_ltx2_transformer, load_ltx2_vae, load_ltx2_vocoder
1011

1112
class LTX2LoadingTest(unittest.TestCase):
1213
def test_loading(self):
@@ -82,8 +83,32 @@ def create_vae():
8283
self.assertEqual(abstract_vae.latent_channels, 128)
8384
# self.assertEqual(len(abstract_vae.encoder.down_blocks), 4) # nnx.List not len()able directly? depends on version
8485

86+
8587
print("VAE structure verified.")
8688

89+
def test_vocoder_loading(self):
  """Build an abstract LTX2Vocoder with nnx.eval_shape and check its out_channels.

  Only the module structure is verified; no checkpoint weights are loaded here.
  """
  # Configuration for Lightricks/LTX-2 Vocoder
  vocoder_config = {
      "in_channels": 128,
      "hidden_channels": 1024,
      "out_channels": 2,
      "upsample_kernel_sizes": (16, 15, 8, 4, 4),
      "upsample_factors": (6, 5, 2, 2, 2),
      "resnet_kernel_sizes": (3, 7, 11),
      "resnet_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
      "leaky_relu_negative_slope": 0.1,
      "output_sampling_rate": 24000,
  }

  def build_vocoder():
    # Rngs are created inside the factory so they are part of the traced construction.
    return LTX2Vocoder(**vocoder_config, rngs=nnx.Rngs(0), dtype=jnp.float32)

  abstract_vocoder = nnx.eval_shape(build_vocoder)
  self.assertEqual(abstract_vocoder.out_channels, 2)
  print("Vocoder structure verified.")
110+
111+
87112

88113
if __name__ == "__main__":
89114
unittest.main()

0 commit comments

Comments
 (0)