Skip to content

Commit 99a1e1e

Browse files
committed
vocoder weight
1 parent 9e802bb commit 99a1e1e

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
588588
state = dict(nnx.to_flat_state(state))
589589

590590
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
591-
params = load_vocoder_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="vocoder", filename=filename)
591+
subfolder = "" if getattr(config, "model_name", "") == "ltx2.3" else "vocoder"
592+
params = load_vocoder_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder=subfolder, filename=filename)
592593
if hasattr(config, "weights_dtype"):
593594
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
594595

0 commit comments

Comments
 (0)