3636from ...models .ltx2 .vocoder_ltx2 import LTX2Vocoder
3737from ...models .ltx2 .vocoder_bwe_ltx2 import LTX2VocoderWithBWE , Vocoder , MelSTFT
3838from ...models .ltx2 .transformer_ltx2 import LTX2VideoTransformer3DModel
39- from ...models .ltx2 .ltx2_3_utils import load_connectors_weights_2_3 , load_vae_weights_2_3
4039from ...models .ltx2 .ltx2_utils import (
4140 load_transformer_weights ,
4241 load_vae_weights ,
4342 load_audio_vae_weights ,
4443 load_vocoder_weights ,
44+ load_connector_weights ,
4545)
4646from ...models .ltx2 .text_encoders .text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
4747from ...video_processor import VideoProcessor
@@ -364,40 +364,17 @@ def load_text_encoder(cls, config: HyperParameters):
364364 return text_encoder
365365
366366 @classmethod
367- def load_connectors (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters , tensors : dict = None ):
367+ def load_connectors (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters ):
368368 max_logging .log ("Loading Connectors..." )
369369
370370 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
371- connector_kwargs = {
372- "dtype" : jnp .float32 ,
373- "weights_dtype" : config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
374- }
375- if getattr (config , "model_name" , "" ) == "ltx2.3" :
376- connector_kwargs .update (
377- {
378- "video_connector_num_layers" : 8 ,
379- "audio_connector_num_layers" : 8 ,
380- "caption_channels" : 3840 ,
381- "video_caption_channels" : 4096 ,
382- "audio_caption_channels" : 2048 ,
383- "video_connector_num_attention_heads" : 32 ,
384- "audio_connector_num_attention_heads" : 32 ,
385- "video_connector_attention_head_dim" : 128 ,
386- "audio_connector_attention_head_dim" : 64 ,
387- "video_gated_attn" : True ,
388- "audio_gated_attn" : True ,
389- "per_modality_projections" : True ,
390- "proj_bias" : True ,
391- "rope_type" : "split" ,
392- }
393- )
394- connector_repo = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
395371 connectors = LTX2AudioVideoGemmaTextEncoder .from_config (
396- connector_repo ,
372+ config . pretrained_model_name_or_path ,
397373 subfolder = "connectors" ,
398374 rngs = rngs ,
399375 mesh = mesh ,
400- ** connector_kwargs ,
376+ dtype = jnp .float32 ,
377+ weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
401378 )
402379 return connectors
403380
@@ -411,16 +388,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
411388 logical_state_sharding = dict (nnx .to_flat_state (logical_state_sharding ))
412389 params = state .to_pure_dict ()
413390 state = dict (nnx .to_flat_state (state ))
414- filename = "ltx-2.3-22b-dev.safetensors" if getattr (config , "model_name" , "" ) == "ltx2.3" else None
415- params = load_connectors_weights_2_3 (
416- config .pretrained_model_name_or_path ,
417- params ,
418- "cpu" ,
419- subfolder = "" ,
420- filename = filename ,
421- is_ltx2_3 = (getattr (config , "model_name" , "" ) == "ltx2.3" ),
422- tensors = tensors ,
423- )
391+
392+ params = load_connector_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "connectors" )
424393 if hasattr (config , "weights_dtype" ):
425394 params = jax .tree_util .tree_map (lambda x : x .astype (config .weights_dtype ), params )
426395
@@ -437,46 +406,17 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
437406 return connectors
438407
439408 @classmethod
440- def load_vae (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters , tensors : dict = None ):
409+ def load_vae (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters ):
441410 max_logging .log ("Loading Video VAE..." )
442411
443412 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
444- vae_kwargs = {
445- "dtype" : jnp .float32 ,
446- "weights_dtype" : config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
447- }
448- vae_repo = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
449- if getattr (config , "model_name" , "" ) == "ltx2.3" :
450- vae_kwargs .update (
451- {
452- "block_out_channels" : (256 , 512 , 1024 , 1024 ),
453- "decoder_block_out_channels" : (256 , 512 , 512 , 1024 ),
454- "layers_per_block" : (4 , 6 , 4 , 2 , 2 ),
455- "decoder_layers_per_block" : (4 , 6 , 4 , 2 , 2 ),
456- "spatio_temporal_scaling" : (True , True , True , True ),
457- "decoder_spatio_temporal_scaling" : (True , True , True , True ),
458- "decoder_inject_noise" : (False , False , False , False , False ),
459- "downsample_type" : ("spatial" , "temporal" , "spatiotemporal" , "spatiotemporal" ),
460- "upsample_type" : ("spatiotemporal" , "spatiotemporal" , "temporal" , "spatial" ),
461- "upsample_residual" : (False , False , False , False ),
462- "upsample_factor" : (2 , 2 , 1 , 2 ),
463- "patch_size" : 4 ,
464- "patch_size_t" : 1 ,
465- "resnet_norm_eps" : 1e-6 ,
466- "encoder_causal" : True ,
467- "decoder_causal" : False ,
468- "encoder_spatial_padding_mode" : "zeros" ,
469- "decoder_spatial_padding_mode" : "zeros" ,
470- "spatial_compression_ratio" : 32 ,
471- "temporal_compression_ratio" : 8 ,
472- }
473- )
474413 vae = LTX2VideoAutoencoderKL .from_config (
475- vae_repo ,
414+ config . pretrained_model_name_or_path ,
476415 subfolder = "vae" ,
477416 rngs = rngs ,
478417 mesh = mesh ,
479- ** vae_kwargs ,
418+ dtype = jnp .float32 ,
419+ weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
480420 )
481421 return vae
482422
@@ -491,12 +431,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
491431 params = state .to_pure_dict ()
492432 state = dict (nnx .to_flat_state (state ))
493433
494- if getattr (config , "model_name" , "" ) == "ltx2.3" :
495- params = load_vae_weights_2_3 (params , "cpu" , tensors )
496- else :
497- filename = "ltx-2.3-22b-dev.safetensors" if getattr (config , "model_name" , "" ) == "ltx2.3" else None
498- subfolder = "" if getattr (config , "model_name" , "" ) == "ltx2.3" else "vae"
499- params = load_vae_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = subfolder , filename = filename )
434+ params = load_vae_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "vae" )
500435 if hasattr (config , "weights_dtype" ):
501436 params = jax .tree_util .tree_map (lambda x : x .astype (config .weights_dtype ), params )
502437
@@ -519,13 +454,12 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
519454 return vae
520455
521456 @classmethod
522- def load_audio_vae (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters , tensors : dict = None ):
457+ def load_audio_vae (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters ):
523458 max_logging .log ("Loading Audio VAE..." )
524459
525460 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
526- vae_repo = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
527461 audio_vae = FlaxAutoencoderKLLTX2Audio .from_config (
528- vae_repo ,
462+ config . pretrained_model_name_or_path ,
529463 subfolder = "audio_vae" ,
530464 rngs = rngs ,
531465 mesh = mesh ,
@@ -545,13 +479,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
545479 params = state .to_pure_dict ()
546480 state = dict (nnx .to_flat_state (state ))
547481
548- if tensors is not None and getattr (config , "model_name" , "" ) == "ltx2.3" :
549- from maxdiffusion .models .ltx2 .ltx2_3_utils import load_audio_vae_weights_2_3
550- params = load_audio_vae_weights_2_3 (params , "cpu" , tensors )
551- elif getattr (config , "model_name" , "" ) == "ltx2.3" :
552- params = load_audio_vae_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "" , filename = "ltx-2.3-22b-dev.safetensors" )
553- else :
554- params = load_audio_vae_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "audio_vae" )
482+ params = load_audio_vae_weights (config .pretrained_model_name_or_path , params , "cpu" , subfolder = "audio_vae" )
555483 if hasattr (config , "weights_dtype" ):
556484 params = jax .tree_util .tree_map (lambda x : x .astype (config .weights_dtype ), params )
557485
@@ -582,7 +510,6 @@ def load_transformer(
582510 config : HyperParameters ,
583511 restored_checkpoint = None ,
584512 subfolder = "transformer" ,
585- tensors : dict = None ,
586513 ):
587514 with mesh :
588515 transformer = create_sharded_logical_transformer (
@@ -592,45 +519,36 @@ def load_transformer(
592519 config = config ,
593520 restored_checkpoint = restored_checkpoint ,
594521 subfolder = subfolder ,
595- tensors = tensors ,
596522 )
597523 return transformer
598524
599525 @classmethod
600- def load_vocoder (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters , tensors : dict = None ):
526+ def load_vocoder (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters ):
601527 max_logging .log ("Loading Vocoder..." )
602528
603529 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
604- vocoder_repo = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
605530 vocoder = LTX2Vocoder .from_config (
606- vocoder_repo ,
531+ "Lightricks/LTX-2" ,
607532 subfolder = "vocoder" ,
608533 rngs = rngs ,
609534 mesh = mesh ,
610535 dtype = jnp .float32 ,
611536 weights_dtype = config .weights_dtype if hasattr (config , "weights_dtype" ) else jnp .float32 ,
612537 )
613538 return vocoder
614-
539+
615540 p_model_factory = partial (create_model , config = config )
616541 vocoder = nnx .eval_shape (p_model_factory , rngs = rngs )
617542 graphdef , state , rest_of_state = nnx .split (vocoder , nnx .Param , ...)
618543 rest_of_state = jax .tree_util .tree_map (cls ._init_dummy_shape , rest_of_state )
619-
544+
620545 logical_state_spec = nnx .get_partition_spec (state )
621546 logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
622547 logical_state_sharding = dict (nnx .to_flat_state (logical_state_sharding ))
623548 params = state .to_pure_dict ()
624549 state = dict (nnx .to_flat_state (state ))
625-
626- if tensors is not None and getattr (config , "model_name" , "" ) == "ltx2.3" :
627- from maxdiffusion .models .ltx2 .ltx2_utils import load_vocoder_weights
628- params = load_vocoder_weights ("Lightricks/LTX-2" , params , "cpu" , subfolder = "vocoder" )
629- else :
630- filename = "ltx-2.3-22b-dev.safetensors" if getattr (config , "model_name" , "" ) == "ltx2.3" else None
631- subfolder = "" if getattr (config , "model_name" , "" ) == "ltx2.3" else "vocoder"
632- repo_id = "Lightricks/LTX-2" if getattr (config , "model_name" , "" ) == "ltx2.3" else config .pretrained_model_name_or_path
633- params = load_vocoder_weights (repo_id , params , "cpu" , subfolder = subfolder , filename = filename )
550+
551+ params = load_vocoder_weights ("Lightricks/LTX-2" , params , "cpu" , subfolder = "vocoder" )
634552 if hasattr (config , "weights_dtype" ):
635553 params = jax .tree_util .tree_map (lambda x : x .astype (config .weights_dtype ), params )
636554
@@ -657,7 +575,7 @@ def load_scheduler(cls, config: HyperParameters):
657575 return scheduler
658576
659577 @classmethod
660- def _create_common_components (cls , config : HyperParameters , vae_only = False , segregated_weights = None ):
578+ def _create_common_components (cls , config : HyperParameters , vae_only = False ):
661579 devices_array = max_utils .create_device_mesh (config )
662580 mesh = Mesh (devices_array , config .mesh_axes )
663581 rng = jax .random .key (config .seed )
@@ -668,7 +586,6 @@ def _create_common_components(cls, config: HyperParameters, vae_only=False, segr
668586 mesh ,
669587 rngs ,
670588 config ,
671- tensors = segregated_weights .get ("vae" ) if segregated_weights else None
672589 )
673590
674591 components = {
@@ -694,37 +611,25 @@ def _create_common_components(cls, config: HyperParameters, vae_only=False, segr
694611 mesh ,
695612 rngs ,
696613 config ,
697- tensors = segregated_weights .get ("connectors" ) if segregated_weights else None
698614 )
699615 components ["audio_vae" ] = cls .load_audio_vae (
700616 devices_array ,
701617 mesh ,
702618 rngs ,
703619 config ,
704- tensors = segregated_weights .get ("audio_vae" ) if segregated_weights else None
705620 )
706621 components ["vocoder" ] = cls .load_vocoder (
707622 devices_array ,
708623 mesh ,
709624 rngs ,
710625 config ,
711- tensors = segregated_weights .get ("vocoder" ) if segregated_weights else None
712626 )
713627 components ["scheduler" ] = cls .load_scheduler (config )
714628 return components
715629
716630 @classmethod
717631 def _load_and_init (cls , config : HyperParameters , restored_checkpoint , vae_only = False , load_transformer = True ):
718- segregated_weights = None
719- if getattr (config , "model_name" , "" ) == "ltx2.3" :
720- from maxdiffusion .models .ltx2 .ltx2_3_utils import load_and_segregate_ltx2_3_weights
721- max_logging .log ("Loading consolidated LTX-2.3 weights..." )
722- segregated_weights = load_and_segregate_ltx2_3_weights (
723- config .pretrained_model_name_or_path ,
724- filename = "ltx-2.3-22b-dev.safetensors"
725- )
726-
727- components = cls ._create_common_components (config , vae_only , segregated_weights = segregated_weights )
632+ components = cls ._create_common_components (config , vae_only )
728633
729634 transformer = None
730635 if load_transformer :
@@ -735,7 +640,6 @@ def _load_and_init(cls, config: HyperParameters, restored_checkpoint, vae_only=F
735640 rngs = components ["rngs" ],
736641 config = config ,
737642 restored_checkpoint = restored_checkpoint ,
738- tensors = segregated_weights .get ("transformer" ) if segregated_weights else None ,
739643 )
740644
741645 pipeline = cls (
0 commit comments