Skip to content

Commit 36242d2

Browse files
committed
changed input format
1 parent eaa7196 commit 36242d2

2 files changed

Lines changed: 14 additions & 36 deletions

File tree

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
11
### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions:
2-
1. Weight Downloading and Conversion
3-
- If first time running (no local safetensors): \
4-
In the src/maxdiffusion/models/ltx_video/utils folder, run python convert_torch_weights_to_jax.py --download_ckpt_path [location to download safetensors] --output_dir [location to save jax ckpt] --transformer_config_path ../xora_v1.2-13B-balanced-128.json.
5-
- If already have local pytorch checkpoint: \
6-
Replace the --download_ckpt_path with --local_ckpt_path and add corresponding location
7-
2. Restoring Jax Weights into transformer:
8-
- Replace the "ckpt_path" in src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json with jax ckpt path.
9-
- Run python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml in the outer repo folder.
10-
2+
In the folder src/maxdiffusion/models/ltx_video/utils, run:
3+
python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../xora_v1.2-13B-balanced-128.json

src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,16 @@ def main(args):
230230
weight_file = "ltxv-13b-0.9.7-dev.safetensors"
231231

232232
# download from huggingface, otherwise load from local
233-
if args.local_ckpt_path is None:
234-
print("Loading from HF", flush=True)
235-
model_name = "Lightricks/LTX-Video"
236-
local_file_path = hf_hub_download(
237-
repo_id=model_name,
238-
filename=weight_file,
239-
local_dir=args.download_ckpt_path,
240-
local_dir_use_symlinks=False,
241-
)
242-
else:
243-
base_dir = args.local_ckpt_path
244-
local_file_path = os.path.join(base_dir, weight_file)
233+
234+
print("Loading from HF", flush=True)
235+
model_name = "Lightricks/LTX-Video"
236+
absolute_ckpt_path = os.path.abspath(args.ckpt_path)
237+
local_file_path = hf_hub_download(
238+
repo_id=model_name,
239+
filename=weight_file,
240+
local_dir=absolute_ckpt_path,
241+
local_dir_use_symlinks=False,
242+
)
245243
torch_state_dict = load_file(local_file_path)
246244

247245
print("Initializing pytorch transformer..", flush=True)
@@ -284,7 +282,7 @@ def main(args):
284282
params_jax = torch_statedict_to_jax(params_jax, torch_state_dict)
285283

286284
print("Creating checkpointer and jax state for saving..", flush=True)
287-
relative_ckpt_path = args.output_dir
285+
relative_ckpt_path = os.path.join(args.ckpt_path, "jax_weights")
288286
absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
289287
tx = optax.adamw(learning_rate=1e-5)
290288
with jax.default_device("cpu"):
@@ -303,25 +301,12 @@ def main(args):
303301
if __name__ == "__main__":
304302
parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.")
305303
parser.add_argument(
306-
"--local_ckpt_path",
304+
"--ckpt_path",
307305
type=str,
308306
required=False,
309307
help="Local path of the checkpoint to convert. If not provided, will download from huggingface for example '/mnt/ckpt/00536000' or '/opt/dmd-torch-model/ema.pt'",
310308
)
311309

312-
parser.add_argument(
313-
"--download_ckpt_path",
314-
type=str,
315-
required=False,
316-
help="Location to download safetensors from huggingface",
317-
)
318-
319-
parser.add_argument(
320-
"--output_dir",
321-
type=str,
322-
required=True,
323-
help="Path to save the checkpoint to. for example 'gs://lt-research-mm-europe-west4/jax_trainings/converted-from-torch'",
324-
)
325310
parser.add_argument(
326311
"--output_step_num",
327312
default=1,

0 commit comments

Comments
 (0)