|
| 1 | +# Copyright 2025 Lightricks Ltd. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# |
| 15 | +# This implementation is based on the Torch version available at: |
| 16 | +# https://github.com/Lightricks/LTX-Video/tree/main |
| 17 | +import argparse |
| 18 | +import json |
| 19 | +from typing import Any, Dict, Optional |
| 20 | + |
| 21 | + |
| 22 | +import jax |
| 23 | +import jax.numpy as jnp |
| 24 | +from flax.training import train_state |
| 25 | +import optax |
| 26 | +import orbax.checkpoint as ocp |
| 27 | +from safetensors.torch import load_file |
| 28 | +import requests |
| 29 | +from urllib.parse import urljoin |
| 30 | + |
| 31 | +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel |
| 32 | +from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax |
| 33 | + |
| 34 | +from huggingface_hub import hf_hub_download |
| 35 | +import os |
| 36 | +import importlib |
| 37 | + |
| 38 | + |
| 39 | +def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import): |
| 40 | + """ |
| 41 | + Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module. |
| 42 | +
|
| 43 | + Args: |
| 44 | + github_base_url (str): The base URL of the GitHub repo. |
| 45 | + base_path (str): The base path where the new folder will be created. |
| 46 | + target_folder_name (str): The name of the folder to create. |
| 47 | + files_to_move (list): A list of file names to download and move. |
| 48 | + module_to_import (str): The full module path to import. |
| 49 | + """ |
| 50 | + |
| 51 | + target_path = os.path.join(base_path, target_folder_name) |
| 52 | + |
| 53 | + try: |
| 54 | + # Create the target directory |
| 55 | + os.makedirs(target_path, exist_ok=True) |
| 56 | + print(f"Created directory: {target_path}") |
| 57 | + |
| 58 | + # Download and move files |
| 59 | + for file_name in files_to_move: |
| 60 | + file_url = urljoin(github_base_url, file_name) |
| 61 | + destination_path = os.path.join(target_path, file_name) |
| 62 | + |
| 63 | + try: |
| 64 | + response = requests.get(file_url, stream=True) |
| 65 | + response.raise_for_status() |
| 66 | + |
| 67 | + with open(destination_path, "wb") as f: |
| 68 | + for chunk in response.iter_content(chunk_size=8192): |
| 69 | + f.write(chunk) |
| 70 | + |
| 71 | + print(f"Downloaded and moved: {file_name} -> {destination_path}") |
| 72 | + |
| 73 | + except requests.exceptions.RequestException as e: |
| 74 | + print(f"Error downloading {file_name}: {e}") |
| 75 | + except OSError as e: |
| 76 | + print(f"Error writing file {file_name}: {e}") |
| 77 | + print("Files downloaded and moved successfully.") |
| 78 | + |
| 79 | + # Verify that the folder exists |
| 80 | + if not os.path.exists(target_path): |
| 81 | + print(f"Error: Target folder {target_path} does not exist after files download.") |
| 82 | + # Dynamically import the module |
| 83 | + try: |
| 84 | + imported_module = importlib.import_module(module_to_import) |
| 85 | + print(f"Module '{module_to_import}' imported successfully.") |
| 86 | + # Access the class |
| 87 | + transformer_class = getattr(imported_module, "Transformer3DModel") |
| 88 | + print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}") |
| 89 | + return transformer_class |
| 90 | + except ImportError as e: |
| 91 | + print(f"Error importing module '{module_to_import}': {e}") |
| 92 | + except AttributeError as e: |
| 93 | + print(f"Error accessing class 'Transformer3DModel': {e}") |
| 94 | + |
| 95 | + except OSError as e: |
| 96 | + print(f"Error during file system operation: {e}") |
| 97 | + except Exception as e: |
| 98 | + print(f"An unexpected error occurred: {e}") |
| 99 | + |
| 100 | + |
| 101 | +class Checkpointer: |
| 102 | + """ |
| 103 | + Checkpointer - to load and store JAX checkpoints |
| 104 | + """ |
| 105 | + |
| 106 | + STATE_DICT_SHAPE_KEY = "shape" |
| 107 | + STATE_DICT_DTYPE_KEY = "dtype" |
| 108 | + TRAIN_STATE_FILE_NAME = "train_state" |
| 109 | + |
| 110 | + def __init__( |
| 111 | + self, |
| 112 | + checkpoint_dir: str, |
| 113 | + use_zarr3: bool = False, |
| 114 | + save_buffer_size: Optional[int] = None, |
| 115 | + restore_buffer_size: Optional[int] = None, |
| 116 | + ): |
| 117 | + """ |
| 118 | + Constructs the checkpointer object |
| 119 | + """ |
| 120 | + opts = ocp.CheckpointManagerOptions( |
| 121 | + enable_async_checkpointing=True, |
| 122 | + step_format_fixed_length=8, # to make the format of "00000000" |
| 123 | + ) |
| 124 | + self.use_zarr3 = use_zarr3 |
| 125 | + self.save_buffer_size = save_buffer_size |
| 126 | + self.restore_buffer_size = restore_buffer_size |
| 127 | + registry = ocp.DefaultCheckpointHandlerRegistry() |
| 128 | + self.train_state_handler = ocp.PyTreeCheckpointHandler( |
| 129 | + save_concurrent_gb=save_buffer_size, |
| 130 | + restore_concurrent_gb=restore_buffer_size, |
| 131 | + use_zarr3=use_zarr3, |
| 132 | + ) |
| 133 | + registry.add( |
| 134 | + self.TRAIN_STATE_FILE_NAME, |
| 135 | + ocp.args.PyTreeSave, |
| 136 | + self.train_state_handler, |
| 137 | + ) |
| 138 | + self.manager = ocp.CheckpointManager( |
| 139 | + directory=checkpoint_dir, |
| 140 | + options=opts, |
| 141 | + handler_registry=registry, |
| 142 | + ) |
| 143 | + |
| 144 | + @property |
| 145 | + def save_buffer_size_bytes(self) -> Optional[int]: |
| 146 | + if self.save_buffer_size is None: |
| 147 | + return None |
| 148 | + return self.save_buffer_size * 2**30 |
| 149 | + |
| 150 | + @staticmethod |
| 151 | + def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: |
| 152 | + """ |
| 153 | + Converts a state dict to a dictionary stating the shape and dtype of the state_dict elements. |
| 154 | + With this, we can reconstruct the state_dict structure later on. |
| 155 | + """ |
| 156 | + return jax.tree_util.tree_map( |
| 157 | + lambda t: { |
| 158 | + Checkpointer.STATE_DICT_SHAPE_KEY: tuple(t.shape), |
| 159 | + Checkpointer.STATE_DICT_DTYPE_KEY: t.dtype.name, |
| 160 | + }, |
| 161 | + state_dict, |
| 162 | + is_leaf=lambda t: isinstance(t, jax.Array), |
| 163 | + ) |
| 164 | + |
| 165 | + def save( |
| 166 | + self, |
| 167 | + step: int, |
| 168 | + state: train_state.TrainState, |
| 169 | + config: Dict[str, Any], |
| 170 | + ): |
| 171 | + """ |
| 172 | + Saves the checkpoint asynchronously |
| 173 | +
|
| 174 | + NOTE that state is going to be copied for this operation |
| 175 | +
|
| 176 | + Args: |
| 177 | + step (int): The step of the checkpoint |
| 178 | + state (TrainStateWithEma): A trainstate containing both the parameters and the optimizer state |
| 179 | + config (Dict[str, Any]): A dictionary containing the configuration of the model |
| 180 | + """ |
| 181 | + self.wait() |
| 182 | + args = ocp.args.Composite( |
| 183 | + train_state=ocp.args.PyTreeSave( |
| 184 | + state, |
| 185 | + ocdbt_target_data_file_size=self.save_buffer_size_bytes, |
| 186 | + ), |
| 187 | + config=ocp.args.JsonSave(config), |
| 188 | + meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)), |
| 189 | + ) |
| 190 | + self.manager.save( |
| 191 | + step, |
| 192 | + args=args, |
| 193 | + ) |
| 194 | + |
| 195 | + def wait(self): |
| 196 | + """ |
| 197 | + Waits for the checkpoint save operation to complete |
| 198 | + """ |
| 199 | + self.manager.wait_until_finished() |
| 200 | + |
| 201 | + |
| 202 | +""" |
| 203 | +Convert Torch checkpoints to JAX. |
| 204 | +
|
| 205 | +This script loads a Torch checkpoint (either regular or sharded), converts it to Jax weights, and saved it. |
| 206 | +""" |
| 207 | + |
| 208 | + |
| 209 | +def main(args): |
| 210 | + """ |
| 211 | + Convert a Torch checkpoint into JAX. |
| 212 | + """ |
| 213 | + |
| 214 | + if args.output_step_num > 1: |
| 215 | + print( |
| 216 | + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " |
| 217 | + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " |
| 218 | + "training loss when resuming from the converted checkpoint." |
| 219 | + ) |
| 220 | + print("Downloading files from GitHub...") |
| 221 | + github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/" |
| 222 | + ltx_repo_path = "../" |
| 223 | + target_folder = "transformers_pytorch" |
| 224 | + files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"] |
| 225 | + module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d" |
| 226 | + |
| 227 | + Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path) |
| 228 | + |
| 229 | + print("Loading safetensors, flush = True") |
| 230 | + weight_file = "ltxv-13b-0.9.7-dev.safetensors" |
| 231 | + |
| 232 | + # download from huggingface, otherwise load from local |
| 233 | + |
| 234 | + print("Loading from HF", flush=True) |
| 235 | + model_name = "Lightricks/LTX-Video" |
| 236 | + absolute_ckpt_path = os.path.abspath(args.ckpt_path) |
| 237 | + local_file_path = hf_hub_download( |
| 238 | + repo_id=model_name, |
| 239 | + filename=weight_file, |
| 240 | + local_dir=absolute_ckpt_path, |
| 241 | + local_dir_use_symlinks=False, |
| 242 | + ) |
| 243 | + torch_state_dict = load_file(local_file_path) |
| 244 | + |
| 245 | + print("Initializing pytorch transformer..", flush=True) |
| 246 | + transformer_config = json.loads(open(args.transformer_config_path, "r").read()) |
| 247 | + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"] |
| 248 | + for key in ignored_keys: |
| 249 | + if key in transformer_config: |
| 250 | + del transformer_config[key] |
| 251 | + |
| 252 | + transformer = Transformer3DModel.from_config(transformer_config) |
| 253 | + |
| 254 | + print("Loading torch weights into transformer..", flush=True) |
| 255 | + transformer.load_state_dict(torch_state_dict) |
| 256 | + torch_state_dict = transformer.state_dict() |
| 257 | + |
| 258 | + print("Creating jax transformer with params..", flush=True) |
| 259 | + transformer_config["use_tpu_flash_attention"] = True |
| 260 | + in_channels = transformer_config["in_channels"] |
| 261 | + del transformer_config["in_channels"] |
| 262 | + jax_transformer3d = JaxTranformer3DModel( |
| 263 | + **transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch" |
| 264 | + ) |
| 265 | + example_inputs = {} |
| 266 | + batch_size, num_tokens = 2, 256 |
| 267 | + input_shapes = { |
| 268 | + "hidden_states": (batch_size, num_tokens, in_channels), |
| 269 | + "indices_grid": (batch_size, 3, num_tokens), |
| 270 | + "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]), |
| 271 | + "timestep": (batch_size, 256), |
| 272 | + "segment_ids": (batch_size, 256), |
| 273 | + "encoder_attention_segment_ids": (batch_size, 128), |
| 274 | + } |
| 275 | + for name, shape in input_shapes.items(): |
| 276 | + example_inputs[name] = jnp.ones( |
| 277 | + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool |
| 278 | + ) |
| 279 | + params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs) |
| 280 | + |
| 281 | + print("Converting torch params to jax..", flush=True) |
| 282 | + params_jax = torch_statedict_to_jax(params_jax, torch_state_dict) |
| 283 | + |
| 284 | + print("Creating checkpointer and jax state for saving..", flush=True) |
| 285 | + relative_ckpt_path = os.path.join(args.ckpt_path, "jax_weights") |
| 286 | + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) |
| 287 | + tx = optax.adamw(learning_rate=1e-5) |
| 288 | + with jax.default_device("cpu"): |
| 289 | + state = train_state.TrainState( |
| 290 | + step=args.output_step_num, |
| 291 | + apply_fn=jax_transformer3d.apply, |
| 292 | + params=params_jax, |
| 293 | + tx=tx, |
| 294 | + opt_state=tx.init(params_jax), |
| 295 | + ) |
| 296 | + with ocp.CheckpointManager(absolute_ckpt_path) as mngr: |
| 297 | + mngr.save(args.output_step_num, args=ocp.args.StandardSave(state.params)) |
| 298 | + print("Done.", flush=True) |
| 299 | + |
| 300 | + |
| 301 | +if __name__ == "__main__": |
| 302 | + parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.") |
| 303 | + parser.add_argument( |
| 304 | + "--ckpt_path", |
| 305 | + type=str, |
| 306 | + required=False, |
| 307 | + help="Local path of the checkpoint to convert. If not provided, will download from huggingface for example '/mnt/ckpt/00536000' or '/opt/dmd-torch-model/ema.pt'", |
| 308 | + ) |
| 309 | + |
| 310 | + parser.add_argument( |
| 311 | + "--output_step_num", |
| 312 | + default=1, |
| 313 | + type=int, |
| 314 | + required=False, |
| 315 | + help=( |
| 316 | + "The step number to assign to the output checkpoint. The result will be saved using this step value. " |
| 317 | + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " |
| 318 | + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " |
| 319 | + "training loss when resuming from the converted checkpoint." |
| 320 | + ), |
| 321 | + ) |
| 322 | + parser.add_argument( |
| 323 | + "--transformer_config_path", |
| 324 | + default="/opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json", |
| 325 | + type=str, |
| 326 | + required=False, |
| 327 | + help="Path to Transformer3D structure config to load the weights based on.", |
| 328 | + ) |
| 329 | + |
| 330 | + args = parser.parse_args() |
| 331 | + main(args) |
0 commit comments