|
34 | 34 | from ...models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL |
35 | 35 | from ...models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio |
36 | 36 | from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder |
37 | | -from ...models.ltx2.vocoder_bwe_ltx2 import LTX2VocoderWithBWE |
| 37 | +from ...models.ltx2.vocoder_bwe_ltx2 import LTX2VocoderWithBWE, Vocoder, MelSTFT |
38 | 38 | from ...models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel |
39 | 39 | from ...models.ltx2.ltx2_3_utils import load_connectors_weights |
40 | 40 | from ...models.ltx2.ltx2_utils import ( |
@@ -537,20 +537,43 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi |
537 | 537 | max_logging.log("Loading Vocoder...") |
538 | 538 |
|
539 | 539 | def create_model(rngs: nnx.Rngs, config: HyperParameters): |
540 | | - vocoder_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path |
541 | 540 | if getattr(config, "model_name", "") == "ltx2.3": |
542 | | - vocoder_class = LTX2VocoderWithBWE |
| 541 | + # Manually construct for LTX-2.3 to support BWE and avoid TypeError |
| 542 | + base_vocoder = Vocoder( |
| 543 | + upsample_initial_channel=1536, |
| 544 | + rngs=rngs, |
| 545 | + dtype=jnp.float32, |
| 546 | + ) |
| 547 | + bwe_generator = Vocoder( |
| 548 | + upsample_initial_channel=512, |
| 549 | + rngs=rngs, |
| 550 | + dtype=jnp.float32, |
| 551 | + ) |
| 552 | + mel_stft = MelSTFT( |
| 553 | + filter_length=512, |
| 554 | + hop_length=80, |
| 555 | + win_length=512, |
| 556 | + n_mel_channels=64, |
| 557 | + rngs=rngs, |
| 558 | + ) |
| 559 | + vocoder = LTX2VocoderWithBWE( |
| 560 | + vocoder=base_vocoder, |
| 561 | + bwe_generator=bwe_generator, |
| 562 | + mel_stft=mel_stft, |
| 563 | + input_sampling_rate=16000, |
| 564 | + output_sampling_rate=48000, |
| 565 | + hop_length=80, |
| 566 | + rngs=rngs, |
| 567 | + ) |
543 | 568 | else: |
544 | | - vocoder_class = LTX2Vocoder |
545 | | - |
546 | | - vocoder = vocoder_class.from_config( |
547 | | - vocoder_repo, |
548 | | - subfolder="vocoder", |
549 | | - rngs=rngs, |
550 | | - mesh=mesh, |
551 | | - dtype=jnp.float32, |
552 | | - weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32, |
553 | | - ) |
| 569 | + vocoder = LTX2Vocoder.from_config( |
| 570 | + config.pretrained_model_name_or_path, |
| 571 | + subfolder="vocoder", |
| 572 | + rngs=rngs, |
| 573 | + mesh=mesh, |
| 574 | + dtype=jnp.float32, |
| 575 | + weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32, |
| 576 | + ) |
554 | 577 | return vocoder |
555 | 578 |
|
556 | 579 | p_model_factory = partial(create_model, config=config) |
|
0 commit comments