Skip to content

Commit 18e06e8

Browse files
committed
changes made for pipeline components loading
1 parent 26466dc commit 18e06e8

5 files changed

Lines changed: 309 additions & 91 deletions

File tree

src/maxdiffusion/checkpointing/ltx2_checkpointer.py

Lines changed: 8 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -85,97 +85,18 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
8585
max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.ltx2_state.keys()}")
8686
return restored_checkpoint, step
8787

88-
def load_diffusers_checkpoint(self):
89-
config = self.config
90-
max_logging.log("Loading LTX2 components from Hugging Face base models.")
91-
92-
# 1. Tokenizer
93-
max_logging.log("Loading Gemma Tokenizer...")
94-
tokenizer = AutoTokenizer.from_pretrained(
95-
config.pretrained_model_name_or_path,
96-
subfolder="tokenizer",
97-
)
98-
# 3. Connectors
99-
max_logging.log("Loading Connectors...")
100-
connectors = LTX2AudioVideoGemmaTextEncoder.from_pretrained(
101-
config.pretrained_model_name_or_path,
102-
subfolder="connectors",
103-
)
104-
105-
# 4. Video VAE
106-
max_logging.log("Loading Video VAE...")
107-
vae = LTX2VideoAutoencoderKL.from_pretrained(
108-
config.pretrained_model_name_or_path,
109-
subfolder="vae",
110-
)
111-
112-
# 5. Audio VAE
113-
max_logging.log("Loading Audio VAE...")
114-
audio_vae = FlaxAutoencoderKLLTX2Audio.from_pretrained(
115-
config.pretrained_model_name_or_path,
116-
subfolder="audio_vae",
117-
)
118-
119-
# 6. Transformer
120-
max_logging.log("Loading Transformer...")
121-
# NOTE: Transformer weights are usually sharded and loaded separately in generation scripts
122-
# This just instantiates the architecture wrapper or loads full weights.
123-
# In MaxDiffusion we typically let the pipeline or generation script handle sharded loading
124-
# but we load the raw config/eval shape here.
125-
transformer = LTX2VideoTransformer3DModel.from_pretrained(
126-
config.pretrained_model_name_or_path,
127-
subfolder="transformer",
128-
)
129-
130-
# 7. Vocoder
131-
max_logging.log("Loading Vocoder...")
132-
vocoder = LTX2Vocoder.from_pretrained(
133-
config.pretrained_model_name_or_path,
134-
subfolder="vocoder",
135-
)
136-
137-
# 8. Scheduler
138-
max_logging.log("Loading Scheduler...")
139-
scheduler = FlaxFlowMatchScheduler.from_pretrained(
140-
config.pretrained_model_name_or_path,
141-
subfolder="scheduler",
142-
)
143-
# 2. Text Encoder (PyTorch)
144-
max_logging.log("Loading Gemma3 Text Encoder...")
145-
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
146-
config.pretrained_model_name_or_path,
147-
subfolder="text_encoder",
148-
torch_dtype=torch.bfloat16,
149-
)
150-
text_encoder.eval()
151-
152-
153-
154-
pipeline = LTX2Pipeline(
155-
scheduler=scheduler,
156-
vae=vae,
157-
audio_vae=audio_vae,
158-
text_encoder=text_encoder,
159-
tokenizer=tokenizer,
160-
connectors=connectors,
161-
transformer=transformer,
162-
vocoder=vocoder,
163-
)
164-
165-
return pipeline
166-
167-
def load_checkpoint(self, step=None) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
88+
def load_checkpoint(self, step=None, vae_only=False, load_transformer=True) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
16889
restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
16990
opt_state = None
91+
17092
if restored_checkpoint:
171-
max_logging.log("Loading LTX2 pipeline from checkpoint (TODO: implement fully if needed)")
172-
# pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint)
173-
# if "opt_state" in restored_checkpoint.ltx2_state.keys():
174-
# opt_state = restored_checkpoint.ltx2_state["opt_state"]
175-
pipeline = self.load_diffusers_checkpoint() # Fallback for now
93+
max_logging.log("Loading LTX2 pipeline from checkpoint")
94+
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
95+
if "opt_state" in restored_checkpoint.ltx2_state.keys():
96+
opt_state = restored_checkpoint.ltx2_state["opt_state"]
17697
else:
177-
max_logging.log("No checkpoint found, loading default pipeline.")
178-
pipeline = self.load_diffusers_checkpoint()
98+
max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
99+
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
179100

180101
return pipeline, opt_state, step
181102

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,7 @@ def __call__(
11111111
return hidden_states
11121112

11131113

1114-
class LTX2VideoAutoencoderKL(nnx.Module, ConfigMixin):
1114+
class LTX2VideoAutoencoderKL(nnx.Module, FlaxModelMixin, ConfigMixin):
11151115
_supports_gradient_checkpointing = True
11161116
config_name = "config.json"
11171117

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2_audio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ...configuration_utils import ConfigMixin, register_to_config
1212
from ...utils import BaseOutput
1313
from ..vae_flax import FlaxDiagonalGaussianDistribution
14+
from ..modeling_flax_utils import FlaxModelMixin
1415

1516

1617
LATENT_DOWNSAMPLE_FACTOR = 4
@@ -624,7 +625,7 @@ def __call__(self, z, target_frames=None, target_mel_bins=None, train: bool = Fa
624625
return h
625626

626627

627-
class FlaxAutoencoderKLLTX2Audio(nnx.Module, ConfigMixin):
628+
class FlaxAutoencoderKLLTX2Audio(nnx.Module, FlaxModelMixin, ConfigMixin):
628629
"""
629630
LTX2 audio VAE wrapper handling normalization, patchification, and latent sampling.
630631
Operates in NHWC format (Batch, Time, Freq, Channels).

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import jax.numpy as jnp
2222
from flax import nnx
2323
from ... import common_types
24+
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
25+
from maxdiffusion.models.modeling_flax_utils import FlaxModelMixin
2426

2527
Array = common_types.Array
2628
DType = common_types.DType
@@ -87,11 +89,12 @@ def __call__(self, x: Array) -> Array:
8789
return x
8890

8991

90-
class LTX2Vocoder(nnx.Module):
92+
class LTX2Vocoder(nnx.Module, FlaxModelMixin, ConfigMixin):
9193
"""
9294
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
9395
"""
9496

97+
@register_to_config
9598
def __init__(
9699
self,
97100
in_channels: int = 128,

Commit comments (0)