Skip to content

Commit e847423

Browse files
committed
vocoder fix
1 parent 595bfd7 commit e847423

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,12 +485,13 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
485485
config_dict = LTX2Vocoder.load_config(config.pretrained_model_name_or_path, subfolder="vocoder")
486486
if "bwe_in_channels" in config_dict:
487487
max_logging.log("Instantiating LTX2VocoderWithBWE for LTX-2.3...")
488-
# Remove model_type if present to avoid from_config creating wrong class
489-
config_dict.pop("model_type", None)
490-
vocoder = LTX2VocoderWithBWE(
491-
**config_dict,
488+
vocoder = LTX2VocoderWithBWE.from_config(
489+
config.pretrained_model_name_or_path,
490+
subfolder="vocoder",
492491
rngs=rngs,
492+
mesh=mesh,
493493
dtype=jnp.float32,
494+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
494495
)
495496
else:
496497
max_logging.log("Instantiating LTX2Vocoder for LTX-2.0...")

0 commit comments

Comments
 (0)