Skip to content

Commit 9e802bb

Browse files
committed
Construct LTX2VocoderWithBWE manually from Vocoder and MelSTFT components for LTX-2.3
1 parent: dfacf93; commit: 9e802bb

1 file changed

Lines changed: 36 additions & 13 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 36 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from ...models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
3535
from ...models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
3636
from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder
37-
from ...models.ltx2.vocoder_bwe_ltx2 import LTX2VocoderWithBWE
37+
from ...models.ltx2.vocoder_bwe_ltx2 import LTX2VocoderWithBWE, Vocoder, MelSTFT
3838
from ...models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
3939
from ...models.ltx2.ltx2_3_utils import load_connectors_weights
4040
from ...models.ltx2.ltx2_utils import (
@@ -537,20 +537,43 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi
537537
max_logging.log("Loading Vocoder...")
538538

539539
def create_model(rngs: nnx.Rngs, config: HyperParameters):
540-
vocoder_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
541540
if getattr(config, "model_name", "") == "ltx2.3":
542-
vocoder_class = LTX2VocoderWithBWE
541+
# Manually construct for LTX-2.3 to support BWE and avoid TypeError
542+
base_vocoder = Vocoder(
543+
upsample_initial_channel=1536,
544+
rngs=rngs,
545+
dtype=jnp.float32,
546+
)
547+
bwe_generator = Vocoder(
548+
upsample_initial_channel=512,
549+
rngs=rngs,
550+
dtype=jnp.float32,
551+
)
552+
mel_stft = MelSTFT(
553+
filter_length=512,
554+
hop_length=80,
555+
win_length=512,
556+
n_mel_channels=64,
557+
rngs=rngs,
558+
)
559+
vocoder = LTX2VocoderWithBWE(
560+
vocoder=base_vocoder,
561+
bwe_generator=bwe_generator,
562+
mel_stft=mel_stft,
563+
input_sampling_rate=16000,
564+
output_sampling_rate=48000,
565+
hop_length=80,
566+
rngs=rngs,
567+
)
543568
else:
544-
vocoder_class = LTX2Vocoder
545-
546-
vocoder = vocoder_class.from_config(
547-
vocoder_repo,
548-
subfolder="vocoder",
549-
rngs=rngs,
550-
mesh=mesh,
551-
dtype=jnp.float32,
552-
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
553-
)
569+
vocoder = LTX2Vocoder.from_config(
570+
config.pretrained_model_name_or_path,
571+
subfolder="vocoder",
572+
rngs=rngs,
573+
mesh=mesh,
574+
dtype=jnp.float32,
575+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
576+
)
554577
return vocoder
555578

556579
p_model_factory = partial(create_model, config=config)

0 commit comments

Comments
 (0)