|
34 | 34 | from ...schedulers import FlaxFlowMatchScheduler |
35 | 35 | from ...models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL |
36 | 36 | from ...models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio |
37 | | -from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder |
| 37 | +from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder, LTX2VocoderWithBWE |
38 | 38 | from ...models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel |
39 | 39 | from ...models.ltx2.latent_upsampler_ltx2 import LTX2LatentUpsamplerModel |
40 | 40 | from ...models.ltx2.ltx2_utils import ( |
@@ -482,14 +482,27 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi |
482 | 482 | max_logging.log("Loading Vocoder...") |
483 | 483 |
|
def create_model(rngs: nnx.Rngs, config: HyperParameters):
    """Instantiate the vocoder class that matches the checkpoint's vocoder config.

    The config stored under the ``vocoder`` subfolder of
    ``config.pretrained_model_name_or_path`` is loaded first. If it contains
    the key ``"bwe_in_channels"``, ``LTX2VocoderWithBWE`` is built; otherwise
    the plain ``LTX2Vocoder`` is built. Both are created through
    ``from_config`` with identical arguments.

    Args:
        rngs: RNG container forwarded unchanged to ``from_config``.
        config: Run configuration. ``pretrained_model_name_or_path`` is
            required; ``weights_dtype`` is optional and falls back to
            ``jnp.float32`` when the attribute is missing.

    Returns:
        The vocoder instance produced by the selected class's ``from_config``.
    """
    model_path = config.pretrained_model_name_or_path
    vocoder_config = LTX2Vocoder.load_config(model_path, subfolder="vocoder")

    # The presence of the "bwe_in_channels" key is the only signal used to
    # choose between the two vocoder variants.
    if "bwe_in_channels" in vocoder_config:
        vocoder_cls = LTX2VocoderWithBWE
        log_message = "Instantiating LTX2VocoderWithBWE for LTX-2.3..."
    else:
        vocoder_cls = LTX2Vocoder
        log_message = "Instantiating LTX2Vocoder for LTX-2.0..."
    max_logging.log(log_message)

    # `mesh` is not a parameter of this function; it is taken from the
    # enclosing scope. Compute dtype is pinned to float32, while the weights
    # dtype follows the run configuration when it provides one.
    return vocoder_cls.from_config(
        model_path,
        subfolder="vocoder",
        rngs=rngs,
        mesh=mesh,
        dtype=jnp.float32,
        weights_dtype=getattr(config, "weights_dtype", jnp.float32),
    )
494 | 507 |
|
495 | 508 | p_model_factory = partial(create_model, config=config) |
|
0 commit comments