Skip to content

Commit ba3b738

Browse files
committed
vocoder weights changed
1 parent 261a492 commit ba3b738

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,12 +625,13 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
625625
state = dict(nnx.to_flat_state(state))
626626

627627
if tensors is not None and getattr(config, "model_name", "") == "ltx2.3":
628-
from maxdiffusion.models.ltx2.ltx2_3_utils import load_vocoder_weights_2_3
629-
params = load_vocoder_weights_2_3(params, "cpu", tensors)
628+
from maxdiffusion.models.ltx2.ltx2_utils import load_vocoder_weights
629+
params = load_vocoder_weights("Lightricks/LTX-2", params, "cpu", subfolder="vocoder")
630630
else:
631631
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
632632
subfolder = "" if getattr(config, "model_name", "") == "ltx2.3" else "vocoder"
633-
params = load_vocoder_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder=subfolder, filename=filename)
633+
repo_id = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
634+
params = load_vocoder_weights(repo_id, params, "cpu", subfolder=subfolder, filename=filename)
634635
if hasattr(config, "weights_dtype"):
635636
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
636637

0 commit comments

Comments
 (0)