Skip to content

Commit acd501c

Browse files
committed
checkpointer and generation script added
1 parent 8c75d37 commit acd501c

3 files changed

Lines changed: 413 additions & 2 deletions

File tree

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import json
18+
import jax
19+
import numpy as np
20+
from typing import Optional, Tuple
21+
from maxdiffusion.pipelines.ltx2.ltx2_pipeline import LTX2Pipeline
22+
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
23+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
24+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
25+
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder
26+
from maxdiffusion.models.ltx2.vocoder_ltx2 import LTX2Vocoder
27+
from maxdiffusion.schedulers.scheduling_flow_match_flax import FlaxFlowMatchScheduler
28+
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
29+
from maxdiffusion import max_logging, max_utils
30+
from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager
31+
import orbax.checkpoint as ocp
32+
from etils import epath
33+
import torch
34+
35+
# Default checkpoint-type tag for LTX2; LTX2Checkpointer.__init__ uses it as the
# default `checkpoint_type` and forwards it to create_orbax_checkpoint_manager.
LTX2_CHECKPOINT = "LTX2_CHECKPOINT"
36+
37+
class LTX2Checkpointer:
    """Loads and saves LTX2 pipelines and training state through Orbax.

    The checkpointer owns a checkpoint manager built by
    ``create_orbax_checkpoint_manager`` and offers helpers to restore the
    ``ltx2_state`` / ``ltx2_config`` checkpoint items, to assemble an
    ``LTX2Pipeline`` from pretrained components, and to save new steps.
    """

    def __init__(self, config, checkpoint_type: str = LTX2_CHECKPOINT):
        """Stores the configuration and creates the checkpoint manager.

        Args:
          config: Run configuration. ``checkpoint_dir`` is read here, and
            ``dataset_type`` is read when the attribute exists (otherwise
            ``None`` is forwarded).
          checkpoint_type: Tag forwarded to ``create_orbax_checkpoint_manager``.
        """
        self.config = config
        self.checkpoint_type = checkpoint_type
        self.opt_state = None

        checkpoint_dir = self.config.checkpoint_dir
        dataset_type = getattr(config, "dataset_type", None)
        self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
            checkpoint_dir,
            enable_checkpointing=True,
            save_interval_steps=1,
            checkpoint_type=checkpoint_type,
            dataset_type=dataset_type,
        )

    def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
        """Restores the ``ltx2_state`` and ``ltx2_config`` items of a checkpoint.

        Args:
          step: Step to restore. When ``None`` the manager's latest step is used.

        Returns:
          ``(restored_checkpoint, step)``, or ``(None, None)`` when ``step`` is
          ``None`` and the manager reports no latest step.
        """
        manager = self.checkpoint_manager

        if step is None:
            step = manager.latest_step()
            max_logging.log(f"Latest LTX2 checkpoint step: {step}")
            if step is None:
                max_logging.log("No LTX2 checkpoint found.")
                return None, None
        max_logging.log(f"Loading LTX2 checkpoint from step {step}")

        # Derive an abstract (shape/dtype) tree from the saved metadata so that
        # a per-leaf restore argument can be generated for every entry.
        state_metadata = manager.item_metadata(step).ltx2_state
        abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state_metadata)

        def _restore_as_numpy(_leaf):
            # Every leaf is requested back as a numpy array.
            return ocp.RestoreArgs(restore_type=np.ndarray)

        state_restore = ocp.args.PyTreeRestore(
            restore_args=jax.tree.map(_restore_as_numpy, abstract_state),
        )

        max_logging.log("Restoring LTX2 checkpoint")
        restored = manager.restore(
            directory=epath.Path(self.config.checkpoint_dir),
            step=step,
            args=ocp.args.Composite(
                ltx2_state=state_restore,
                ltx2_config=ocp.args.JsonRestore(),
            ),
        )

        max_logging.log(f"restored checkpoint {restored.keys()}")
        state_keys = restored.ltx2_state.keys()
        max_logging.log(f"restored checkpoint ltx2_state {state_keys}")
        has_optimizer = "opt_state" in state_keys
        max_logging.log(f"optimizer found in checkpoint {has_optimizer}")
        return restored, step

    def load_diffusers_checkpoint(self):
        """Builds an ``LTX2Pipeline`` from pretrained components.

        Each component is loaded with ``from_pretrained`` from
        ``config.pretrained_model_name_or_path`` using its own subfolder.

        Returns:
          The assembled ``LTX2Pipeline``.
        """
        config = self.config

        def _pretrained(component_cls, subfolder, **extra):
            # Shared call shape for all components: same model id, one subfolder each.
            return component_cls.from_pretrained(
                config.pretrained_model_name_or_path,
                subfolder=subfolder,
                **extra,
            )

        max_logging.log("Loading LTX2 components from Hugging Face base models.")

        # 1. Tokenizer
        max_logging.log("Loading Gemma Tokenizer...")
        tokenizer = _pretrained(AutoTokenizer, "tokenizer")

        # 2. Text encoder (PyTorch), loaded in bfloat16 and switched to eval mode.
        max_logging.log("Loading Gemma3 Text Encoder...")
        text_encoder = _pretrained(
            Gemma3ForConditionalGeneration,
            "text_encoder",
            torch_dtype=torch.bfloat16,
        )
        text_encoder.eval()

        # 3. Connectors
        max_logging.log("Loading Connectors...")
        connectors = _pretrained(LTX2AudioVideoGemmaTextEncoder, "connectors")

        # 4. Video VAE
        max_logging.log("Loading Video VAE...")
        vae = _pretrained(LTX2VideoAutoencoderKL, "vae")

        # 5. Audio VAE
        max_logging.log("Loading Audio VAE...")
        audio_vae = _pretrained(FlaxAutoencoderKLLTX2Audio, "audio_vae")

        # 6. Transformer
        # NOTE: transformer weights are usually sharded and loaded separately by
        # the generation scripts; in MaxDiffusion the pipeline or generation
        # script typically handles sharded loading. This call only instantiates
        # the architecture wrapper (or loads full weights) from the raw config.
        max_logging.log("Loading Transformer...")
        transformer = _pretrained(LTX2VideoTransformer3DModel, "transformer")

        # 7. Vocoder
        max_logging.log("Loading Vocoder...")
        vocoder = _pretrained(LTX2Vocoder, "vocoder")

        # 8. Scheduler
        max_logging.log("Loading Scheduler...")
        scheduler = _pretrained(FlaxFlowMatchScheduler, "scheduler")

        return LTX2Pipeline(
            scheduler=scheduler,
            vae=vae,
            audio_vae=audio_vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            connectors=connectors,
            transformer=transformer,
            vocoder=vocoder,
        )

    def load_checkpoint(self, step=None) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
        """Returns a pipeline, optimizer state and step for the requested step.

        The Orbax checkpoint is restored first, but its contents are not yet
        applied: in both cases the pipeline comes from
        ``load_diffusers_checkpoint`` and the optimizer state is ``None``.

        Args:
          step: Step to restore, or ``None`` for the latest available step.

        Returns:
          ``(pipeline, opt_state, step)``.
        """
        restored, step = self.load_ltx2_configs_from_orbax(step)
        opt_state = None

        # TODO: once implemented, build the pipeline from the restored checkpoint
        # (e.g. LTX2Pipeline.from_checkpoint(self.config, restored)) and take
        # opt_state from restored.ltx2_state["opt_state"] when present. Until
        # then both paths fall back to the pretrained components.
        if restored:
            message = "Loading LTX2 pipeline from checkpoint (TODO: implement fully if needed)"
        else:
            message = "No checkpoint found, loading default pipeline."
        max_logging.log(message)
        pipeline = self.load_diffusers_checkpoint()

        return pipeline, opt_state, step

    def save_checkpoint(self, train_step, pipeline: LTX2Pipeline, train_states: dict):
        """Saves the training state and the transformer configuration.

        Args:
          train_step: Step number under which the checkpoint is saved.
          pipeline: Pipeline whose ``transformer`` provides the JSON config
            (via ``to_json_string``) stored as ``ltx2_config``.
          train_states: PyTree stored as ``ltx2_state``.
        """
        max_logging.log(f"Saving checkpoint for step {train_step}")

        transformer_config = json.loads(pipeline.transformer.to_json_string())
        save_args = ocp.args.Composite(
            ltx2_config=ocp.args.JsonSave(transformer_config),
            ltx2_state=ocp.args.PyTreeSave(train_states),
        )

        self.checkpoint_manager.save(train_step, args=save_args)
        max_logging.log(f"Checkpoint for step {train_step} saved.")
194+

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ max_sequence_length: 512
2323
sampler: "from_checkpoint"
2424

2525
# Generation parameters
26+
global_batch_size_to_train_on: 1
27+
num_inference_steps: 40
28+
guidance_scale: 3.0
29+
fps: 24
2630
pipeline_type: multi-scale
2731
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
2832
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
@@ -92,11 +96,13 @@ ici_tensor_parallelism: 1
9296
allow_split_physical_axes: False
9397
learning_rate_schedule_steps: -1
9498
max_train_steps: 500
95-
pretrained_model_name_or_path: ''
99+
pretrained_model_name_or_path: 'Lightricks/LTX-Video'
100+
model_name: "ltx_video"
101+
model_type: "T2V"
96102
unet_checkpoint: ''
97103
dataset_name: 'diffusers/pokemon-gpt4-captions'
98104
train_split: 'train'
99-
dataset_type: 'tf'
105+
dataset_type: 'tfrecord'
100106
cache_latents_text_encoder_outputs: True
101107
per_device_batch_size: 1
102108
compile_topology_num_slices: -1

0 commit comments

Comments
 (0)