Skip to content

Commit f93c3bd

Browse files
committed
Added running instructions
1 parent d1c304d commit f93c3bd

1 file changed

Lines changed: 13 additions & 0 deletions

File tree

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
### Transformer Pytorch Weight Downloading and Jax Weight Loading Instructions:
2+
1. Create new tansformers_pytorch folder under models/ltx_video.
3+
2. Move files attention.py, embeddings.py, symmetric_patchifier.py, transformer3d.py into the newly created folder.
4+
3. Rename transformer3d.py to transformer_pt.py to distinguish from the pytorch version. Change classname to Transformer3DModel_PT. Also change classname in line "transformer = Transformer3DModel.from_config(transformer_config)"
5+
4. Weight Downloading and Conversion
6+
- If first time running (no local safetensors): \
7+
In the src/maxdiffusion/models/ltx_video/utils folder, run python convert_torch_weights_to_jax.py --download_ckpt_path [location to download safetensors] --output_dir [location to save jax ckpt] --transformer_config_path ../xora_v1.2-13B-balanced-128.json.
8+
- If already have local pytorch checkpoint: \
9+
Replace the --download_ckpt_path with --local_ckpt_path and add corresponding location
10+
5. Restoring Jax Weights into transformer:
11+
- Replace the "ckpt_path" in src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json with jax ckpt path.
12+
- Run python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml in the outer repo folder.
13+

0 commit comments

Comments
 (0)