Skip to content

Commit 66e9140

Browse files
committed
vocoder weight
1 parent 99a1e1e commit 66e9140

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,8 @@ def load_vocoder_weights(
382382
if filename and pt_key.startswith("vocoder."):
383383
pt_key = pt_key[len("vocoder."):]
384384
key = rename_for_ltx2_vocoder(pt_key)
385+
if filename == "ltx-2.3-22b-dev.safetensors":
386+
key = key.replace("resblocks_", "resnets.")
385387
parts = key.split(".")
386388

387389
if parts[-1] == "weight":
@@ -398,6 +400,9 @@ def load_vocoder_weights(
398400
tensor = tensor.transpose(2, 0, 1)[::-1, :, :]
399401
else:
400402
tensor = tensor.transpose(2, 1, 0)
403+
404+
if "mel_stft" in flax_key and ("forward_basis" in flax_key or "inverse_basis" in flax_key):
405+
tensor = tensor.transpose(2, 1, 0)
401406

402407
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
403408

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
546546
)
547547
bwe_generator = Vocoder(
548548
upsample_initial_channel=512,
549+
upsample_kernel_sizes=[12, 11, 4, 4, 4],
549550
rngs=rngs,
550551
dtype=jnp.float32,
551552
)

0 commit comments

Comments
 (0)