
Commit b31a97b

conversion script added
1 parent 7bed4f9 commit b31a97b

4 files changed

Lines changed: 959 additions & 0 deletions

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
import argparse
import json
import os
from typing import Any, Dict, Optional

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
from flax.training import train_state
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTransformer3DModel
from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax

class Checkpointer:
    """
    Checkpointer - loads and stores JAX checkpoints.
    """

    STATE_DICT_SHAPE_KEY = "shape"
    STATE_DICT_DTYPE_KEY = "dtype"
    TRAIN_STATE_FILE_NAME = "train_state"

    def __init__(
        self,
        checkpoint_dir: str,
        use_zarr3: bool = False,
        save_buffer_size: Optional[int] = None,
        restore_buffer_size: Optional[int] = None,
    ):
        """
        Constructs the checkpointer object.
        """
        opts = ocp.CheckpointManagerOptions(
            enable_async_checkpointing=True,
            step_format_fixed_length=8,  # zero-pads step directories to the "00000000" format
        )
        self.use_zarr3 = use_zarr3
        self.save_buffer_size = save_buffer_size
        self.restore_buffer_size = restore_buffer_size
        registry = ocp.DefaultCheckpointHandlerRegistry()
        self.train_state_handler = ocp.PyTreeCheckpointHandler(
            save_concurrent_gb=save_buffer_size,
            restore_concurrent_gb=restore_buffer_size,
            use_zarr3=use_zarr3,
        )
        registry.add(
            self.TRAIN_STATE_FILE_NAME,
            ocp.args.PyTreeSave,
            self.train_state_handler,
        )
        self.manager = ocp.CheckpointManager(
            directory=checkpoint_dir,
            options=opts,
            handler_registry=registry,
        )

    @property
    def save_buffer_size_bytes(self) -> Optional[int]:
        if self.save_buffer_size is None:
            return None
        return self.save_buffer_size * 2**30  # GiB -> bytes

    @staticmethod
    def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Converts a state dict to a dictionary stating the shape and dtype of the
        state_dict elements. With this, we can reconstruct the state_dict
        structure later on.
        """
        return jax.tree_util.tree_map(
            lambda t: {
                Checkpointer.STATE_DICT_SHAPE_KEY: tuple(t.shape),
                Checkpointer.STATE_DICT_DTYPE_KEY: t.dtype.name,
            },
            state_dict,
            is_leaf=lambda t: isinstance(t, jax.Array),
        )

    def save(
        self,
        step: int,
        state: train_state.TrainState,
        config: Dict[str, Any],
    ):
        """
        Saves the checkpoint asynchronously.

        NOTE that state is going to be copied for this operation.

        Args:
            step (int): The step of the checkpoint
            state (train_state.TrainState): A train state containing both the parameters and the optimizer state
            config (Dict[str, Any]): A dictionary containing the configuration of the model
        """
        # Make sure any previous async save has finished before starting a new one.
        self.wait()
        args = ocp.args.Composite(
            train_state=ocp.args.PyTreeSave(
                state,
                ocdbt_target_data_file_size=self.save_buffer_size_bytes,
            ),
            config=ocp.args.JsonSave(config),
            meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)),
        )
        self.manager.save(
            step,
            args=args,
        )

    def wait(self):
        """
        Waits for the checkpoint save operation to complete.
        """
        self.manager.wait_until_finished()

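# A minimal usage sketch for the Checkpointer above (it is not exercised by
# this script, which saves via a bare CheckpointManager in main() below).
# The state and config values here are illustrative placeholders:
#
#   ckptr = Checkpointer(checkpoint_dir="/tmp/jax_ckpts", save_buffer_size=1)
#   ckptr.save(step=0, state=some_train_state, config={"model": "ltx-video"})
#   ckptr.wait()  # block until the async save has been written out
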
"""
122+
Convert Torch checkpoints to JAX.
123+
124+
This script loads a Torch checkpoint (either regular or sharded), converts it to Jax weights, and saved it.
125+
"""
126+
127+
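
# Example invocation (the script filename and paths are illustrative, not
# taken from this commit):
#
#   python convert_checkpoint.py \
#       --download_ckpt_path /tmp/hf_ckpt \
#       --output_dir /tmp/converted-from-torch \
#       --transformer_config_path /opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json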

def main(args):
    """
    Convert a Torch checkpoint into JAX.
    """

    if args.output_step_num > 1:
        print(
            "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between "
            "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in "
            "training loss when resuming from the converted checkpoint."
        )

    print("Loading safetensors", flush=True)
    weight_file = "ltxv-13b-0.9.7-dev.safetensors"

    # Download from Hugging Face; otherwise load from a local path.
    if args.local_ckpt_path is None:
        print("Loading from HF", flush=True)
        model_name = "Lightricks/LTX-Video"
        local_file_path = hf_hub_download(
            repo_id=model_name,
            filename=weight_file,
            local_dir=args.download_ckpt_path,
            local_dir_use_symlinks=False,
        )
    else:
        base_dir = args.local_ckpt_path
        local_file_path = os.path.join(base_dir, weight_file)
    torch_state_dict = load_file(local_file_path)

    print("Initializing pytorch transformer..", flush=True)
    with open(args.transformer_config_path, "r") as f:
        transformer_config = json.load(f)
    ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"]
    for key in ignored_keys:
        if key in transformer_config:
            del transformer_config[key]

    transformer = Transformer3DModel_PT.from_config(transformer_config)

    print("Loading torch weights into transformer..", flush=True)
    transformer.load_state_dict(torch_state_dict)
    torch_state_dict = transformer.state_dict()

    print("Creating jax transformer with params..", flush=True)
    transformer_config["use_tpu_flash_attention"] = True
    in_channels = transformer_config.pop("in_channels")
    jax_transformer3d = JaxTransformer3DModel(
        **transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch"
    )
    example_inputs = {}
    batch_size, num_tokens = 2, 256
    input_shapes = {
        "hidden_states": (batch_size, num_tokens, in_channels),
        "indices_grid": (batch_size, 3, num_tokens),
        "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]),
        "timestep": (batch_size, 256),
        "segment_ids": (batch_size, 256),
        "encoder_attention_segment_ids": (batch_size, 128),
    }
    for name, shape in input_shapes.items():
        example_inputs[name] = jnp.ones(
            shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool_
        )
    params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs)
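    # Note (an aside, not part of the original script): if host memory is a
    # concern, the parameter structure could be obtained without materializing
    # the random init, e.g. via
    #   jax.eval_shape(jax_transformer3d.init, jax.random.PRNGKey(42), **example_inputs)
    # provided the conversion below only needs the tree structure, shapes, and dtypes.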

    print("Converting torch params to jax..", flush=True)
    params_jax = torch_statedict_to_jax(params_jax, torch_state_dict)

    print("Creating checkpointer and jax state for saving..", flush=True)
    relative_ckpt_path = args.output_dir
    absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
    tx = optax.adamw(learning_rate=1e-5)
    # Build the state on host CPU; jax.default_device expects a Device instance.
    with jax.default_device(jax.devices("cpu")[0]):
        state = train_state.TrainState(
            step=args.output_step_num,
            apply_fn=jax_transformer3d.apply,
            params=params_jax,
            tx=tx,
            opt_state=tx.init(params_jax),
        )
    with ocp.CheckpointManager(absolute_ckpt_path) as mngr:
        mngr.save(args.output_step_num, args=ocp.args.StandardSave(state.params))
    print("Done.", flush=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.")
    parser.add_argument(
        "--local_ckpt_path",
        type=str,
        required=False,
        help="Local path of the checkpoint to convert, for example '/mnt/ckpt/00536000' or "
        "'/opt/dmd-torch-model/ema.pt'. If not provided, the checkpoint will be downloaded from huggingface.",
    )

    parser.add_argument(
        "--download_ckpt_path",
        type=str,
        required=False,
        help="Directory to download the safetensors file from huggingface into.",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path to save the checkpoint to, for example 'gs://lt-research-mm-europe-west4/jax_trainings/converted-from-torch'.",
    )
    parser.add_argument(
        "--output_step_num",
        default=1,
        type=int,
        required=False,
        help=(
            "The step number to assign to the output checkpoint. The result will be saved using this step value. "
            "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between "
            "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in "
            "training loss when resuming from the converted checkpoint."
        ),
    )
    parser.add_argument(
        "--transformer_config_path",
        default="/opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json",
        type=str,
        required=False,
        help="Path to the Transformer3D structure config to load the weights based on.",
    )

    args = parser.parse_args()
    main(args)
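
To sanity-check the result, the saved params can be read back with Orbax (a minimal sketch, assuming the default output step of 1 and an illustrative local output dir; StandardRestore without an abstract tree restores the arrays as they were saved):

import orbax.checkpoint as ocp

with ocp.CheckpointManager("/tmp/converted-from-torch") as mngr:
    restored_params = mngr.restore(1, args=ocp.args.StandardRestore())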
