Skip to content

Commit f6115df

Browse files
committed
auto script
1 parent 2737877 commit f6115df

1 file changed

Lines changed: 77 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,87 @@
22
import json
33
from typing import Any, Dict, Optional
44

5+
6+
57
import jax
68
import jax.numpy as jnp
79
from flax.training import train_state
810
import optax
911
import orbax.checkpoint as ocp
1012
from safetensors.torch import load_file
13+
import requests
14+
import shutil
15+
from urllib.parse import urljoin
1116

12-
from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT
17+
# from maxdiffusion.models.ltx_video.transformers_pytorch.transformer import Transformer3DModel
1318
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel
1419
from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax
1520

1621
from huggingface_hub import hf_hub_download
1722
import os
23+
import importlib
24+
def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import):
25+
"""
26+
Downloads files from a GitHub repo, moves them to a local folder, and then dynamically imports a module.
1827
28+
Args:
29+
github_base_url (str): The base URL of the GitHub repo.
30+
base_path (str): The base path where the new folder will be created.
31+
target_folder_name (str): The name of the folder to create.
32+
files_to_move (list): A list of file names to download and move.
33+
module_to_import (str): The full module path to import.
34+
"""
1935

36+
target_path = os.path.join(base_path, target_folder_name)
37+
38+
try:
39+
# Create the target directory
40+
os.makedirs(target_path, exist_ok=True)
41+
print(f"Created directory: {target_path}")
42+
43+
# Download and move files
44+
for file_name in files_to_move:
45+
file_url = urljoin(github_base_url, file_name)
46+
destination_path = os.path.join(target_path, file_name)
47+
48+
try:
49+
response = requests.get(file_url, stream=True)
50+
response.raise_for_status()
51+
52+
with open(destination_path, 'wb') as f:
53+
for chunk in response.iter_content(chunk_size=8192):
54+
f.write(chunk)
55+
56+
print(f"Downloaded and moved: {file_name} -> {destination_path}")
57+
58+
except requests.exceptions.RequestException as e:
59+
print(f"Error downloading {file_name}: {e}")
60+
return # Stop if there is an error.
61+
except OSError as e:
62+
print(f"Error writing file {file_name}: {e}")
63+
return # Stop if there is an error.
64+
print("Files downloaded and moved successfully.")
65+
66+
# Verify that the folder exists
67+
if not os.path.exists(target_path):
68+
print(f"Error: Target folder {target_path} does not exist after files download.")
69+
# Dynamically import the module
70+
try:
71+
imported_module = importlib.import_module(module_to_import)
72+
print(f"Module '{module_to_import}' imported successfully.")
73+
# Access the class
74+
transformer_class = getattr(imported_module, "Transformer3DModel")
75+
print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}")
76+
return transformer_class
77+
except ImportError as e:
78+
print(f"Error importing module '{module_to_import}': {e}")
79+
except AttributeError as e:
80+
print(f"Error accessing class 'Transformer3DModel': {e}")
81+
82+
except OSError as e:
83+
print(f"Error during file system operation: {e}")
84+
except Exception as e:
85+
print(f"An unexpected error occurred: {e}")
2086
class Checkpointer:
2187
"""
2288
Checkpointer - to load and store JAX checkpoints
@@ -136,7 +202,15 @@ def main(args):
136202
"the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in "
137203
"training loss when resuming from the converted checkpoint."
138204
)
139-
205+
print("Downloading files from GitHub...")
206+
github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/"
207+
ltx_repo_path = "../"
208+
target_folder = "transformers_pytorch"
209+
files = ["attention.py", "embeddings.py", "symmetric_patchifier.py", "transformer3d.py"]
210+
module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d"
211+
212+
Transformer3DModel = download_and_move_files(github_url, ltx_repo_path, target_folder, files, module_path)
213+
140214
print("Loading safetensors, flush = True")
141215
weight_file = "ltxv-13b-0.9.7-dev.safetensors"
142216

@@ -162,7 +236,7 @@ def main(args):
162236
if key in transformer_config:
163237
del transformer_config[key]
164238

165-
transformer = Transformer3DModel_PT.from_config(transformer_config)
239+
transformer = Transformer3DModel.from_config(transformer_config)
166240

167241
print("Loading torch weights into transformer..", flush=True)
168242
transformer.load_state_dict(torch_state_dict)

0 commit comments

Comments
 (0)