Skip to content

Commit 8cac9af

Browse files
committed
checkpoint switched
1 parent 346a127 commit 8cac9af

2 files changed

Lines changed: 24 additions & 120 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ replicate_vae: False
9393
allow_split_physical_axes: False
9494
learning_rate_schedule_steps: -1
9595
max_train_steps: 500
96-
pretrained_model_name_or_path: 'Lightricks/LTX-2.3'
96+
pretrained_model_name_or_path: 'dg845/LTX-2.3-Diffusers'
9797
model_name: "ltx2.3"
9898
model_type: "T2V"
9999
unet_checkpoint: ''

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 23 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@
3636
from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder
3737
from ...models.ltx2.vocoder_bwe_ltx2 import LTX2VocoderWithBWE, Vocoder, MelSTFT
3838
from ...models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
39-
from ...models.ltx2.ltx2_3_utils import load_connectors_weights_2_3, load_vae_weights_2_3
4039
from ...models.ltx2.ltx2_utils import (
4140
load_transformer_weights,
4241
load_vae_weights,
4342
load_audio_vae_weights,
4443
load_vocoder_weights,
44+
load_connector_weights,
4545
)
4646
from ...models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
4747
from ...video_processor import VideoProcessor
@@ -364,40 +364,17 @@ def load_text_encoder(cls, config: HyperParameters):
364364
return text_encoder
365365

366366
@classmethod
367-
def load_connectors(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, tensors: dict = None):
367+
def load_connectors(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
368368
max_logging.log("Loading Connectors...")
369369

370370
def create_model(rngs: nnx.Rngs, config: HyperParameters):
371-
connector_kwargs = {
372-
"dtype": jnp.float32,
373-
"weights_dtype": config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
374-
}
375-
if getattr(config, "model_name", "") == "ltx2.3":
376-
connector_kwargs.update(
377-
{
378-
"video_connector_num_layers": 8,
379-
"audio_connector_num_layers": 8,
380-
"caption_channels": 3840,
381-
"video_caption_channels": 4096,
382-
"audio_caption_channels": 2048,
383-
"video_connector_num_attention_heads": 32,
384-
"audio_connector_num_attention_heads": 32,
385-
"video_connector_attention_head_dim": 128,
386-
"audio_connector_attention_head_dim": 64,
387-
"video_gated_attn": True,
388-
"audio_gated_attn": True,
389-
"per_modality_projections": True,
390-
"proj_bias": True,
391-
"rope_type": "split",
392-
}
393-
)
394-
connector_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
395371
connectors = LTX2AudioVideoGemmaTextEncoder.from_config(
396-
connector_repo,
372+
config.pretrained_model_name_or_path,
397373
subfolder="connectors",
398374
rngs=rngs,
399375
mesh=mesh,
400-
**connector_kwargs,
376+
dtype=jnp.float32,
377+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
401378
)
402379
return connectors
403380

@@ -411,16 +388,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
411388
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
412389
params = state.to_pure_dict()
413390
state = dict(nnx.to_flat_state(state))
414-
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
415-
params = load_connectors_weights_2_3(
416-
config.pretrained_model_name_or_path,
417-
params,
418-
"cpu",
419-
subfolder="",
420-
filename=filename,
421-
is_ltx2_3=(getattr(config, "model_name", "") == "ltx2.3"),
422-
tensors=tensors,
423-
)
391+
392+
params = load_connector_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="connectors")
424393
if hasattr(config, "weights_dtype"):
425394
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
426395

@@ -437,46 +406,17 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
437406
return connectors
438407

439408
@classmethod
440-
def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, tensors: dict = None):
409+
def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
441410
max_logging.log("Loading Video VAE...")
442411

443412
def create_model(rngs: nnx.Rngs, config: HyperParameters):
444-
vae_kwargs = {
445-
"dtype": jnp.float32,
446-
"weights_dtype": config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
447-
}
448-
vae_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
449-
if getattr(config, "model_name", "") == "ltx2.3":
450-
vae_kwargs.update(
451-
{
452-
"block_out_channels": (256, 512, 1024, 1024),
453-
"decoder_block_out_channels": (256, 512, 512, 1024),
454-
"layers_per_block": (4, 6, 4, 2, 2),
455-
"decoder_layers_per_block": (4, 6, 4, 2, 2),
456-
"spatio_temporal_scaling": (True, True, True, True),
457-
"decoder_spatio_temporal_scaling": (True, True, True, True),
458-
"decoder_inject_noise": (False, False, False, False, False),
459-
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
460-
"upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"),
461-
"upsample_residual": (False, False, False, False),
462-
"upsample_factor": (2, 2, 1, 2),
463-
"patch_size": 4,
464-
"patch_size_t": 1,
465-
"resnet_norm_eps": 1e-6,
466-
"encoder_causal": True,
467-
"decoder_causal": False,
468-
"encoder_spatial_padding_mode": "zeros",
469-
"decoder_spatial_padding_mode": "zeros",
470-
"spatial_compression_ratio": 32,
471-
"temporal_compression_ratio": 8,
472-
}
473-
)
474413
vae = LTX2VideoAutoencoderKL.from_config(
475-
vae_repo,
414+
config.pretrained_model_name_or_path,
476415
subfolder="vae",
477416
rngs=rngs,
478417
mesh=mesh,
479-
**vae_kwargs,
418+
dtype=jnp.float32,
419+
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
480420
)
481421
return vae
482422

@@ -491,12 +431,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
491431
params = state.to_pure_dict()
492432
state = dict(nnx.to_flat_state(state))
493433

494-
if getattr(config, "model_name", "") == "ltx2.3":
495-
params = load_vae_weights_2_3(params, "cpu", tensors)
496-
else:
497-
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
498-
subfolder = "" if getattr(config, "model_name", "") == "ltx2.3" else "vae"
499-
params = load_vae_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder=subfolder, filename=filename)
434+
params = load_vae_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="vae")
500435
if hasattr(config, "weights_dtype"):
501436
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
502437

@@ -519,13 +454,12 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
519454
return vae
520455

521456
@classmethod
522-
def load_audio_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, tensors: dict = None):
457+
def load_audio_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
523458
max_logging.log("Loading Audio VAE...")
524459

525460
def create_model(rngs: nnx.Rngs, config: HyperParameters):
526-
vae_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
527461
audio_vae = FlaxAutoencoderKLLTX2Audio.from_config(
528-
vae_repo,
462+
config.pretrained_model_name_or_path,
529463
subfolder="audio_vae",
530464
rngs=rngs,
531465
mesh=mesh,
@@ -545,13 +479,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
545479
params = state.to_pure_dict()
546480
state = dict(nnx.to_flat_state(state))
547481

548-
if tensors is not None and getattr(config, "model_name", "") == "ltx2.3":
549-
from maxdiffusion.models.ltx2.ltx2_3_utils import load_audio_vae_weights_2_3
550-
params = load_audio_vae_weights_2_3(params, "cpu", tensors)
551-
elif getattr(config, "model_name", "") == "ltx2.3":
552-
params = load_audio_vae_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="", filename="ltx-2.3-22b-dev.safetensors")
553-
else:
554-
params = load_audio_vae_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="audio_vae")
482+
params = load_audio_vae_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="audio_vae")
555483
if hasattr(config, "weights_dtype"):
556484
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
557485

@@ -582,7 +510,6 @@ def load_transformer(
582510
config: HyperParameters,
583511
restored_checkpoint=None,
584512
subfolder="transformer",
585-
tensors: dict = None,
586513
):
587514
with mesh:
588515
transformer = create_sharded_logical_transformer(
@@ -592,45 +519,36 @@ def load_transformer(
592519
config=config,
593520
restored_checkpoint=restored_checkpoint,
594521
subfolder=subfolder,
595-
tensors=tensors,
596522
)
597523
return transformer
598524

599525
@classmethod
600-
def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, tensors: dict = None):
526+
def load_vocoder(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
601527
max_logging.log("Loading Vocoder...")
602528

603529
def create_model(rngs: nnx.Rngs, config: HyperParameters):
604-
vocoder_repo = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
605530
vocoder = LTX2Vocoder.from_config(
606-
vocoder_repo,
531+
"Lightricks/LTX-2",
607532
subfolder="vocoder",
608533
rngs=rngs,
609534
mesh=mesh,
610535
dtype=jnp.float32,
611536
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
612537
)
613538
return vocoder
614-
539+
615540
p_model_factory = partial(create_model, config=config)
616541
vocoder = nnx.eval_shape(p_model_factory, rngs=rngs)
617542
graphdef, state, rest_of_state = nnx.split(vocoder, nnx.Param, ...)
618543
rest_of_state = jax.tree_util.tree_map(cls._init_dummy_shape, rest_of_state)
619-
544+
620545
logical_state_spec = nnx.get_partition_spec(state)
621546
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
622547
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
623548
params = state.to_pure_dict()
624549
state = dict(nnx.to_flat_state(state))
625-
626-
if tensors is not None and getattr(config, "model_name", "") == "ltx2.3":
627-
from maxdiffusion.models.ltx2.ltx2_utils import load_vocoder_weights
628-
params = load_vocoder_weights("Lightricks/LTX-2", params, "cpu", subfolder="vocoder")
629-
else:
630-
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
631-
subfolder = "" if getattr(config, "model_name", "") == "ltx2.3" else "vocoder"
632-
repo_id = "Lightricks/LTX-2" if getattr(config, "model_name", "") == "ltx2.3" else config.pretrained_model_name_or_path
633-
params = load_vocoder_weights(repo_id, params, "cpu", subfolder=subfolder, filename=filename)
550+
551+
params = load_vocoder_weights("Lightricks/LTX-2", params, "cpu", subfolder="vocoder")
634552
if hasattr(config, "weights_dtype"):
635553
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
636554

@@ -657,7 +575,7 @@ def load_scheduler(cls, config: HyperParameters):
657575
return scheduler
658576

659577
@classmethod
660-
def _create_common_components(cls, config: HyperParameters, vae_only=False, segregated_weights=None):
578+
def _create_common_components(cls, config: HyperParameters, vae_only=False):
661579
devices_array = max_utils.create_device_mesh(config)
662580
mesh = Mesh(devices_array, config.mesh_axes)
663581
rng = jax.random.key(config.seed)
@@ -668,7 +586,6 @@ def _create_common_components(cls, config: HyperParameters, vae_only=False, segr
668586
mesh,
669587
rngs,
670588
config,
671-
tensors=segregated_weights.get("vae") if segregated_weights else None
672589
)
673590

674591
components = {
@@ -694,37 +611,25 @@ def _create_common_components(cls, config: HyperParameters, vae_only=False, segr
694611
mesh,
695612
rngs,
696613
config,
697-
tensors=segregated_weights.get("connectors") if segregated_weights else None
698614
)
699615
components["audio_vae"] = cls.load_audio_vae(
700616
devices_array,
701617
mesh,
702618
rngs,
703619
config,
704-
tensors=segregated_weights.get("audio_vae") if segregated_weights else None
705620
)
706621
components["vocoder"] = cls.load_vocoder(
707622
devices_array,
708623
mesh,
709624
rngs,
710625
config,
711-
tensors=segregated_weights.get("vocoder") if segregated_weights else None
712626
)
713627
components["scheduler"] = cls.load_scheduler(config)
714628
return components
715629

716630
@classmethod
717631
def _load_and_init(cls, config: HyperParameters, restored_checkpoint, vae_only=False, load_transformer=True):
718-
segregated_weights = None
719-
if getattr(config, "model_name", "") == "ltx2.3":
720-
from maxdiffusion.models.ltx2.ltx2_3_utils import load_and_segregate_ltx2_3_weights
721-
max_logging.log("Loading consolidated LTX-2.3 weights...")
722-
segregated_weights = load_and_segregate_ltx2_3_weights(
723-
config.pretrained_model_name_or_path,
724-
filename="ltx-2.3-22b-dev.safetensors"
725-
)
726-
727-
components = cls._create_common_components(config, vae_only, segregated_weights=segregated_weights)
632+
components = cls._create_common_components(config, vae_only)
728633

729634
transformer = None
730635
if load_transformer:
@@ -735,7 +640,6 @@ def _load_and_init(cls, config: HyperParameters, restored_checkpoint, vae_only=F
735640
rngs=components["rngs"],
736641
config=config,
737642
restored_checkpoint=restored_checkpoint,
738-
tensors=segregated_weights.get("transformer") if segregated_weights else None,
739643
)
740644

741645
pipeline = cls(

0 commit comments

Comments
 (0)