@@ -537,33 +537,35 @@ def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, confi
537537 max_logging .log ("Loading Vocoder..." )
538538
539539 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
540+ vocoder_repo = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
540541 if getattr (config , "model_name" , "" ) == "ltx2.3" :
541542 vocoder_class = LTX2VocoderWithBWE
542543 else :
543544 vocoder_class = LTX2Vocoder
544-
545+
545546 vocoder = vocoder_class .from_config (
546- config . pretrained_model_name_or_path ,
547+ vocoder_repo ,
547548 subfolder = "vocoder" ,
548549 rngs = rngs ,
549550 mesh = mesh ,
550551 dtype = jnp .float32 ,
551552 weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
552553 )
553554 return vocoder
554-
555+
555556 p_model_factory = partial (create_model , config = config )
556557 vocoder = nnx .eval_shape (p_model_factory , rngs = rngs )
557558 graphdef , state , rest_of_state = nnx .split (vocoder , nnx .Param , ...)
558559 rest_of_state = jax .tree_util .tree_map (cls ._init_dummy_shape , rest_of_state )
559-
560+
560561 logical_state_spec = nnx .get_partition_spec (state )
561562 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
562563 logical_state_sharding = dict (nnx .to_flat_state (logical_state_sharding ))
563564 params = state .to_pure_dict ()
564565 state = dict (nnx .to_flat_state (state ))
565-
566- params = load_vocoder_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "vocoder" )
566+
567+ filename = "ltx-2.3-22b-dev.safetensors" if getattr (config , "model_name" , "" ) == "ltx2.3" else None
568+ params = load_vocoder_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "vocoder" , filename = filename )
567569 if hasattr (config , "weights_dtype" ):
568570 params = jax .tree_util .tree_map (lambda x : x .astype (config .weights_dtype ), params )
569571
0 commit comments