22import json
33from typing import Any , Dict , Optional
44
5+
6+
57import jax
68import jax .numpy as jnp
79from flax .training import train_state
810import optax
911import orbax .checkpoint as ocp
1012from safetensors .torch import load_file
13+ import requests
14+ import shutil
15+ from urllib .parse import urljoin
1116
12- from maxdiffusion .models .ltx_video .transformers_pytorch .transformer_pt import Transformer3DModel_PT
17+ # from maxdiffusion.models.ltx_video.transformers_pytorch.transformer import Transformer3DModel
1318from maxdiffusion .models .ltx_video .transformers .transformer3d import Transformer3DModel as JaxTranformer3DModel
1419from maxdiffusion .models .ltx_video .utils .torch_compat import torch_statedict_to_jax
1520
1621from huggingface_hub import hf_hub_download
1722import os
23+ import importlib
def download_and_move_files(github_base_url, base_path, target_folder_name, files_to_move, module_to_import):
    """Download source files from a GitHub repo into a local package folder, then import a module from it.

    Args:
        github_base_url (str): Base raw-content URL of the GitHub repo. Should end
            with '/' so urljoin appends file names instead of replacing the last
            path segment.
        base_path (str): The base path where the new folder will be created.
        target_folder_name (str): The name of the folder to create.
        files_to_move (list): File names to download into the target folder.
        module_to_import (str): The full dotted module path to import afterwards.

    Returns:
        The ``Transformer3DModel`` class from the imported module, or ``None`` if
        any download, file-system, or import step failed. Errors are printed,
        never raised (best-effort semantics preserved for callers).
    """
    target_path = os.path.join(base_path, target_folder_name)

    try:
        # Create the target directory (idempotent).
        os.makedirs(target_path, exist_ok=True)
        print(f"Created directory: {target_path}")

        # The dotted import below needs the folder to be a package; an empty
        # __init__.py makes that explicit instead of relying on namespace packages.
        init_path = os.path.join(target_path, "__init__.py")
        if not os.path.exists(init_path):
            open(init_path, "a").close()

        # Download and move files.
        for file_name in files_to_move:
            file_url = urljoin(github_base_url, file_name)
            destination_path = os.path.join(target_path, file_name)

            try:
                # `with` closes the connection; the timeout prevents hanging
                # forever on an unresponsive server.
                with requests.get(file_url, stream=True, timeout=60) as response:
                    response.raise_for_status()
                    with open(destination_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                print(f"Downloaded and moved: {file_name} -> {destination_path}")
            except requests.exceptions.RequestException as e:
                print(f"Error downloading {file_name}: {e}")
                return None  # Stop if there is an error.
            except OSError as e:
                print(f"Error writing file {file_name}: {e}")
                return None  # Stop if there is an error.
        print("Files downloaded and moved successfully.")

        # Sanity check: bail out rather than attempting an import that must fail.
        if not os.path.exists(target_path):
            print(f"Error: Target folder {target_path} does not exist after files download.")
            return None

        # The module files were created after interpreter start-up, so the import
        # system's finder caches must be invalidated before import_module.
        importlib.invalidate_caches()
        try:
            imported_module = importlib.import_module(module_to_import)
            print(f"Module '{module_to_import}' imported successfully.")
            # Access the class the caller needs.
            transformer_class = getattr(imported_module, "Transformer3DModel")
            print(f"Class 'Transformer3DModel' accessed successfully: {transformer_class}")
            return transformer_class
        except ImportError as e:
            print(f"Error importing module '{module_to_import}': {e}")
        except AttributeError as e:
            print(f"Error accessing class 'Transformer3DModel': {e}")

    except OSError as e:
        print(f"Error during file system operation: {e}")
    except Exception as e:
        # Last-resort guard: keep the caller-facing contract of printing the
        # problem and returning None instead of crashing.
        print(f"An unexpected error occurred: {e}")
    return None
2086class Checkpointer :
2187 """
2288 Checkpointer - to load and store JAX checkpoints
@@ -136,7 +202,15 @@ def main(args):
136202 "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in "
137203 "training loss when resuming from the converted checkpoint."
138204 )
139-
205+ print ("Downloading files from GitHub..." )
206+ github_url = "https://raw.githubusercontent.com/Lightricks/LTX-Video/main/ltx_video/models/transformers/"
207+ ltx_repo_path = "../"
208+ target_folder = "transformers_pytorch"
209+ files = ["attention.py" , "embeddings.py" , "symmetric_patchifier.py" , "transformer3d.py" ]
210+ module_path = "maxdiffusion.models.ltx_video.transformers_pytorch.transformer3d"
211+
212+ Transformer3DModel = download_and_move_files (github_url , ltx_repo_path , target_folder , files , module_path )
213+
140214 print ("Loading safetensors, flush = True" )
141215 weight_file = "ltxv-13b-0.9.7-dev.safetensors"
142216
@@ -162,7 +236,7 @@ def main(args):
162236 if key in transformer_config :
163237 del transformer_config [key ]
164238
165- transformer = Transformer3DModel_PT .from_config (transformer_config )
239+ transformer = Transformer3DModel .from_config (transformer_config )
166240
167241 print ("Loading torch weights into transformer.." , flush = True )
168242 transformer .load_state_dict (torch_state_dict )
0 commit comments