Skip to content

Commit 9209e80

Browse files
committed
vocoder fix
1 parent 6e9f8ff commit 9209e80

1 file changed

Lines changed: 22 additions & 9 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ...schedulers import FlaxFlowMatchScheduler
3535
from ...models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
3636
from ...models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
37-
from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder
37+
from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder, LTX2VocoderWithBWE
3838
from ...models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
3939
from ...models.ltx2.latent_upsampler_ltx2 import LTX2LatentUpsamplerModel
4040
from ...models.ltx2.ltx2_utils import (
@@ -482,14 +482,27 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi
482482
max_logging.log("Loading Vocoder...")
483483

484484
def create_model(rngs: nnx.Rngs, config: HyperParameters):
485-
vocoder = LTX2Vocoder.from_config(
486-
config.pretrained_model_name_or_path,
487-
subfolder="vocoder",
488-
rngs=rngs,
489-
mesh=mesh,
490-
dtype=jnp.float32,
491-
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
492-
)
485+
config_dict = LTX2Vocoder.load_config(config.pretrained_model_name_or_path, subfolder="vocoder")
486+
if "bwe_in_channels" in config_dict:
487+
max_logging.log("Instantiating LTX2VocoderWithBWE for LTX-2.3...")
488+
vocoder = LTX2VocoderWithBWE.from_config(
489+
config.pretrained_model_name_or_path,
490+
subfolder="vocoder",
491+
rngs=rngs,
492+
mesh=mesh,
493+
dtype=jnp.float32,
494+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
495+
)
496+
else:
497+
max_logging.log("Instantiating LTX2Vocoder for LTX-2.0...")
498+
vocoder = LTX2Vocoder.from_config(
499+
config.pretrained_model_name_or_path,
500+
subfolder="vocoder",
501+
rngs=rngs,
502+
mesh=mesh,
503+
dtype=jnp.float32,
504+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
505+
)
493506
return vocoder
494507

495508
p_model_factory = partial(create_model, config=config)

0 commit comments

Comments
 (0)