
Commit b4bd96e: merged conversion
2 parents: 774e2c4 + 8fc3626

2 files changed: 867 additions & 0 deletions
This file: 331 additions & 0 deletions
# Copyright 2025 Lightricks Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This implementation is based on the Torch version available at:
# https://github.com/Lightricks/LTX-Video/tree/main
import argparse
import importlib
import json
import os
from typing import Any, Dict, Optional
from urllib.parse import urljoin

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
import requests
from flax.training import train_state
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTransformer3DModel
from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax


def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import):
  """
  Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module.

  Args:
      github_base_url (str): The base URL of the GitHub repo.
      base_path (str): The base path where the new folder will be created.
      target_folder_name (str): The name of the folder to create.
      files_to_move (list): A list of file names to download and move.
      module_to_import (str): The full module path to import.
  """

  target_path = os.path.join(base_path, target_folder_name)

  try:
    # Create the target directory
    os.makedirs(target_path, exist_ok=True)
    print(f"Created directory: {target_path}")

    # Download and move files
    for file_name in files_to_move:
      file_url = urljoin(github_base_url, file_name)
      destination_path = os.path.join(target_path, file_name)

      try:
        response = requests.get(file_url, stream=True)
        response.raise_for_status()

        with open(destination_path, "wb") as f:
          for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

        print(f"Downloaded and moved: {file_name} -> {destination_path}")

      except requests.exceptions.RequestException as e:
        print(f"Error downloading {file_name}: {e}")
      except OSError as e:
        print(f"Error writing file {file_name}: {e}")
    print("Files downloaded and moved successfully.")

    # Verify that the folder exists
    if not os.path.exists(target_path):
      print(f"Error: Target folder {target_path} does not exist after files download.")
    # Dynamically import the module
    try:
      imported_module = importlib.import_module(module_to_import)
      print(f"Module '{module_to_import}' imported successfully.")
      # Access the class
      transformer_class = getattr(imported_module, "Transformer3DModel")
      print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}")
      return transformer_class
    except ImportError as e:
      print(f"Error importing module '{module_to_import}': {e}")
    except AttributeError as e:
      print(f"Error accessing class 'Transformer3DModel': {e}")

  except OSError as e:
    print(f"Error during file system operation: {e}")
  except Exception as e:
    print(f"An unexpected error occurred: {e}")


class Checkpointer:
  """
  Checkpointer - to load and store JAX checkpoints
  """

  STATE_DICT_SHAPE_KEY = "shape"
  STATE_DICT_DTYPE_KEY = "dtype"
  TRAIN_STATE_FILE_NAME = "train_state"

  def __init__(
      self,
      checkpoint_dir: str,
      use_zarr3: bool = False,
      save_buffer_size: Optional[int] = None,
      restore_buffer_size: Optional[int] = None,
  ):
    """
    Constructs the checkpointer object
    """
    opts = ocp.CheckpointManagerOptions(
        enable_async_checkpointing=True,
        step_format_fixed_length=8,  # pad step numbers to the "00000000" format
    )
    self.use_zarr3 = use_zarr3
    self.save_buffer_size = save_buffer_size
    self.restore_buffer_size = restore_buffer_size
    registry = ocp.DefaultCheckpointHandlerRegistry()
    self.train_state_handler = ocp.PyTreeCheckpointHandler(
        save_concurrent_gb=save_buffer_size,
        restore_concurrent_gb=restore_buffer_size,
        use_zarr3=use_zarr3,
    )
    registry.add(
        self.TRAIN_STATE_FILE_NAME,
        ocp.args.PyTreeSave,
        self.train_state_handler,
    )
    self.manager = ocp.CheckpointManager(
        directory=checkpoint_dir,
        options=opts,
        handler_registry=registry,
    )

  @property
  def save_buffer_size_bytes(self) -> Optional[int]:
    if self.save_buffer_size is None:
      return None
    return self.save_buffer_size * 2**30  # buffer sizes are given in GiB

  @staticmethod
  def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Converts a state dict to a dictionary stating the shape and dtype of the state_dict elements.
    With this, we can reconstruct the state_dict structure later on.
    """
    return jax.tree_util.tree_map(
        lambda t: {
            Checkpointer.STATE_DICT_SHAPE_KEY: tuple(t.shape),
            Checkpointer.STATE_DICT_DTYPE_KEY: t.dtype.name,
        },
        state_dict,
        is_leaf=lambda t: isinstance(t, jax.Array),
    )
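
  # Illustration: for a (hypothetical) params tree like
  #   {"dense": {"kernel": jnp.zeros((4, 8), jnp.bfloat16)}}
  # this returns
  #   {"dense": {"kernel": {"shape": (4, 8), "dtype": "bfloat16"}}}
  # which is enough to rebuild matching jax.ShapeDtypeStruct((4, 8), jnp.bfloat16)
  # leaves at restore time.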

  def save(
      self,
      step: int,
      state: train_state.TrainState,
      config: Dict[str, Any],
  ):
    """
    Saves the checkpoint asynchronously

    NOTE that state is going to be copied for this operation

    Args:
        step (int): The step of the checkpoint
        state (train_state.TrainState): A train state containing both the parameters and the optimizer state
        config (Dict[str, Any]): A dictionary containing the configuration of the model
    """
    self.wait()
    args = ocp.args.Composite(
        train_state=ocp.args.PyTreeSave(
            state,
            ocdbt_target_data_file_size=self.save_buffer_size_bytes,
        ),
        config=ocp.args.JsonSave(config),
        meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)),
    )
    self.manager.save(
        step,
        args=args,
    )

  def wait(self):
    """
    Waits for the checkpoint save operation to complete
    """
    self.manager.wait_until_finished()
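

# Minimal usage sketch for Checkpointer (illustrative only; `my_state` and
# `my_config` are hypothetical stand-ins, and nothing here is executed):
#
#   ckpt = Checkpointer("/tmp/ckpts", save_buffer_size=1)
#   ckpt.save(step=100, state=my_state, config=my_config)  # async save
#   ckpt.wait()  # block until the save is fully written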


"""
Convert Torch checkpoints to JAX.

This script loads a Torch checkpoint (either regular or sharded), converts it to JAX weights, and saves it.
"""


def main(args):
  """
  Convert a Torch checkpoint into JAX.
  """

  if args.output_step_num > 1:
    print(
        "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between "
        "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in "
        "training loss when resuming from the converted checkpoint."
    )
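  # (Only the model params are written out below, via StandardSave(state.params);
  # the optimizer state created by tx.init() is never saved, hence the warning.)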
  print("Downloading files from GitHub...")
  github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/"
  ltx_repo_path = "../"
  target_folder = "transformers_pytorch"
  files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"]
  module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d"

  Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path)

  print("Loading safetensors", flush=True)
  weight_file = "ltxv-13b-0.9.7-dev.safetensors"

  # Download from Hugging Face; hf_hub_download reuses the file if it is already present in local_dir.
  print("Loading from HF", flush=True)
  model_name = "Lightricks/LTX-Video"
  absolute_ckpt_path = os.path.abspath(args.ckpt_path)
  local_file_path = hf_hub_download(
      repo_id=model_name,
      filename=weight_file,
      local_dir=absolute_ckpt_path,
      local_dir_use_symlinks=False,
  )
  torch_state_dict = load_file(local_file_path)

  print("Initializing pytorch transformer..", flush=True)
  with open(args.transformer_config_path, "r") as f:
    transformer_config = json.load(f)
  ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"]
  for key in ignored_keys:
    if key in transformer_config:
      del transformer_config[key]

  transformer = Transformer3DModel.from_config(transformer_config)

  print("Loading torch weights into transformer..", flush=True)
  transformer.load_state_dict(torch_state_dict)
  torch_state_dict = transformer.state_dict()

  print("Creating jax transformer with params..", flush=True)
  transformer_config["use_tpu_flash_attention"] = True
  in_channels = transformer_config["in_channels"]
  del transformer_config["in_channels"]
  jax_transformer3d = JaxTransformer3DModel(
      **transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch"
  )
  example_inputs = {}
  batch_size, num_tokens = 2, 256
  input_shapes = {
      "hidden_states": (batch_size, num_tokens, in_channels),
      "indices_grid": (batch_size, 3, num_tokens),
      "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]),
      "timestep": (batch_size, 256),
      "segment_ids": (batch_size, 256),
      "encoder_attention_segment_ids": (batch_size, 128),
  }
  for name, shape in input_shapes.items():
    example_inputs[name] = jnp.ones(
        shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool_
    )
  params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs)
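
  # Note: init() only traces the model to build the parameter pytree, so the dummy
  # inputs above need the right ranks and shapes, not meaningful values. That pytree
  # is then overwritten below by torch_statedict_to_jax, which maps the torch
  # state-dict entries onto the corresponding flax parameters.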

  print("Converting torch params to jax..", flush=True)
  params_jax = torch_statedict_to_jax(params_jax, torch_state_dict)

  print("Creating checkpointer and jax state for saving..", flush=True)
  relative_ckpt_path = os.path.join(args.ckpt_path, "jax_weights")
  absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
  tx = optax.adamw(learning_rate=1e-5)
  with jax.default_device(jax.devices("cpu")[0]):
    state = train_state.TrainState(
        step=args.output_step_num,
        apply_fn=jax_transformer3d.apply,
        params=params_jax,
        tx=tx,
        opt_state=tx.init(params_jax),
    )
  with ocp.CheckpointManager(absolute_ckpt_path) as mngr:
    mngr.save(args.output_step_num, args=ocp.args.StandardSave(state.params))
  print("Done.", flush=True)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Convert Torch checkpoints to JAX format.")
  parser.add_argument(
      "--ckpt_path",
      type=str,
      required=False,
      help=(
          "Local checkpoint directory, e.g. '/mnt/ckpt/00536000' or '/opt/dmd-torch-model'. The Torch weights "
          "are downloaded from Hugging Face into this directory, and the converted JAX weights are saved under "
          "'<ckpt_path>/jax_weights'."
      ),
  )

  parser.add_argument(
      "--output_step_num",
      default=1,
      type=int,
      required=False,
      help=(
          "The step number to assign to the output checkpoint. The result will be saved using this step value. "
          "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between "
          "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in "
          "training loss when resuming from the converted checkpoint."
      ),
  )
  parser.add_argument(
      "--transformer_config_path",
      default="/opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json",
      type=str,
      required=False,
      help="Path to the Transformer3D structure config to load the weights based on.",
  )

  args = parser.parse_args()
  main(args)
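
# Example invocation (illustrative paths; the script name is whatever this file is
# saved as):
#
#   python convert_ltx_checkpoint.py \
#       --ckpt_path /mnt/ckpt/ltxv \
#       --transformer_config_path /opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json \
#       --output_step_num 1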
