@@ -625,12 +625,13 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
625625 state = dict (nnx .to_flat_state (state ))
626626
627627 if tensors is not None and getattr (config , "model_name" , "" ) == "ltx2.3" :
628- from maxdiffusion .models .ltx2 .ltx2_3_utils import load_vocoder_weights_2_3
629- params = load_vocoder_weights_2_3 ( params , "cpu" , tensors )
628+ from maxdiffusion .models .ltx2 .ltx2_utils import load_vocoder_weights
629+ params = load_vocoder_weights ( "Lightricks/LTX-2" , params , "cpu" , subfolder = "vocoder" )
630630 else :
631631 filename = "ltx-2.3-22b-dev.safetensors" if getattr (config , "model_name" , "" ) == "ltx2.3" else None
632632 subfolder = "" if getattr (config , "model_name" , "" ) == "ltx2.3" else "vocoder"
633- params = load_vocoder_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = subfolder , filename = filename )
633+ repo_id = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
634+ params = load_vocoder_weights (repo_id , params , "cpu" , subfolder = subfolder , filename = filename )
634635 if hasattr (config , "weights_dtype" ):
635636 params = jax .tree_util .tree_map (lambda x : x .astype (config .weights_dtype ), params )
636637
0 commit comments