-
import os
import json
import torch
@@ -28,79 +27,73 @@ def rename_for_ltx2_transformer(key):
2827 """
2928 Renames Diffusers LTX-2 keys to MaxDiffusion Flax LTX-2 keys.
3029 """
31- # General replacements
3230 key = key .replace ("patchify_proj" , "proj_in" )
3331 key = key .replace ("audio_patchify_proj" , "audio_proj_in" )
34- key = key .replace ("transformer_blocks" , "transformer_blocks" ) # kept same
-
-  # AdaLN / Timestep Embed handling
-  # Diffusers uses: time_embed, audio_time_embed, av_cross_attn_...
-  # These match Flax implementation names mostly.
-
-  # Attention QK Norms -> Flax uses "norm_q", "norm_k" (Diffusers often uses q_norm, k_norm but conversion script mapped them to norm_q/norm_k already?
-  # Wait, the conversion script maps *from* original *to* Diffusers.
-  # If loading Diffusers checkpoint, we should expect "norm_q", "norm_k" if that's what Diffusers uses.
-  # Checking conversion script: LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT maps "q_norm" -> "norm_q".
-  # So Diffusers likely uses "norm_q".
-
-  # Handle "weight" -> "kernel" for Linear/Conv layers is done in rename_key_and_reshape_tensor
-  # checking rename_key_and_reshape_tensor: it handles "weight" -> "kernel" for linear/conv.
-
-  # Specific LTX-2 nested renaming
-  # Diffusers: transformer_blocks.0.attn1.to_q.weight
-  # Flax: transformer_blocks.layers.0.attn1.query.kernel (if scanned)
-
-  # rename_key_and_reshape_tensor handles:
-  # to_q -> query
-  # to_k -> key
-  # to_v -> value
-  # to_out.0 -> proj_attn
-
-  # We might need to handle specific mismatches if any.
-
-  # The "scale" vs "weight" for LayerNorm is also handled in rename_key_and_reshape_tensor
-  # BUT only if it detects "norm" in key.
-
-  # LTX2AdaLayerNormSingle usually has "linear" which is a Linear layer.
-
+  key = key.replace("transformer_blocks", "transformer_blocks")  # kept the same
  return key


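For orientation, a minimal example of how the rename helpers compose; the downstream renames listed here come from the notes deleted above, so treat them as assumptions about rename_key_and_reshape_tensor rather than verified output:

    renamed = rename_for_ltx2_transformer("patchify_proj.weight")
    # -> "proj_in.weight"; the remaining renames (to_q -> query, to_k -> key,
    # to_v -> value, to_out.0 -> proj_attn, and "weight" -> "kernel" for
    # linear/conv layers) are applied later by rename_key_and_reshape_tensor.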
def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=48):
  if scan_layers:
    if "transformer_blocks" in pt_tuple_key:
-      # transformer_blocks.0.attn1... -> transformer_blocks.layers.attn1...
-      # We need to extract the block index
-      new_key = ("transformer_blocks",) + pt_tuple_key[2:]  # removing index
+      new_key = ("transformer_blocks",) + pt_tuple_key[2:]  # drop the per-block index from the key
      block_index = int(pt_tuple_key[1])
      pt_tuple_key = new_key

-  # For scanned layers, we need to locate the param in the huge stacked tensor
-  # But wait, rename_key_and_reshape_tensor takes the *modified* pt_tuple_key?
-  # No, it takes the original one usually to check against random_flax_state_dict.
-  # But here we are constructing it.
-
  flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
-
-  # Custom cleaning after generic rename
-  # e.g. converting "weight" to "value" for Params if needed, though they usually just take array.
-
  flax_key = _tuple_str_to_int(flax_key)

  if scan_layers:
    if "transformer_blocks" in flax_key:
-      # We need to stack correct index
      if flax_key in flax_state_dict:
        new_tensor = flax_state_dict[flax_key]
      else:
-        # Initialize with zeros of shape (num_layers, ...) + tensor.shape
        new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype)
-
      new_tensor = new_tensor.at[block_index].set(flax_tensor)
      flax_tensor = new_tensor

  return flax_key, flax_tensor

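The scan-layers branch above assembles one stacked parameter per block name: each per-layer tensor is written into a zero-initialized buffer with a leading num_layers axis. A self-contained sketch of that pattern with toy shapes (not the real checkpoint layout):

    import jax.numpy as jnp

    num_layers = 4
    per_block = [jnp.full((2, 3), i, dtype=jnp.float32) for i in range(num_layers)]  # toy per-block tensors
    stacked = jnp.zeros((num_layers, 2, 3), dtype=jnp.float32)
    for idx, t in enumerate(per_block):
      stacked = stacked.at[idx].set(t)  # same .at[...].set(...) update used in get_key_and_value
    assert stacked.shape == (num_layers, 2, 3)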
+def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
+  """
+  Loads weights from a sharded safetensors checkpoint, falling back to a
+  single safetensors or .bin file when no shard index is present.
+  """
+  index_file = "diffusion_pytorch_model.safetensors.index.json"
+  tensors = {}
+
+  # Try to download the index file first.
+  try:
+    index_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=index_file)
+    with open(index_path, "r") as f:
+      index_data = json.load(f)
+    weight_map = index_data["weight_map"]
+    shards = set(weight_map.values())
+
+    for shard_file in shards:
+      shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=shard_file)
+      with safe_open(shard_path, framework="pt") as f:
+        for k in f.keys():
+          tensors[k] = torch2jax(f.get_tensor(k))
+  except Exception:
+    # Fallback to a single checkpoint file.
+    filename = "diffusion_pytorch_model.safetensors"
+    try:
+      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
+    except Exception:
+      filename = "diffusion_pytorch_model.bin"
+      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
+
+    if filename.endswith(".safetensors"):
+      with safe_open(ckpt_path, framework="pt") as f:
+        for k in f.keys():
+          tensors[k] = torch2jax(f.get_tensor(k))
+    else:
+      loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
+      for k, v in loaded_state_dict.items():
+        tensors[k] = torch2jax(v)
+
+  return tensors
+
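load_sharded_checkpoint only relies on the weight_map field of the index file, which maps each tensor name to the shard that stores it; the loader then downloads the set of distinct shard files. A minimal sketch of the index layout it expects, written as a Python literal with illustrative tensor and file names:

    index_data = {
        "metadata": {"total_size": 123456789},
        "weight_map": {
            "transformer_blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
            "transformer_blocks.1.attn1.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
        },
    }
    shards = set(index_data["weight_map"].values())  # the two shard files that would be downloaded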
def load_transformer_weights(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
@@ -111,57 +104,25 @@ def load_transformer_weights(
    subfolder: str = "transformer",
):
  device = jax.local_devices(backend=device)[0]
-
-  # Determine if local or hub
-  filename = "diffusion_pytorch_model.safetensors"
-  local_files = False
-  if os.path.isdir(pretrained_model_name_or_path):
-    ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-    if not os.path.isfile(ckpt_path):
-      # Try .bin just in case
-      filename = "diffusion_pytorch_model.bin"
-      ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-      if not os.path.isfile(ckpt_path):
-        raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
-    local_files = True
-  elif hf_download:
-    try:
-      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-    except Exception:
-      filename = "diffusion_pytorch_model.bin"
-      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-
  max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")

  with jax.default_device(device):
-    tensors = {}
-    if filename.endswith(".safetensors"):
-      with safe_open(ckpt_path, framework="pt") as f:
-        for k in f.keys():
-          tensors[k] = torch2jax(f.get_tensor(k))
-    else:  # bin/pt
-      loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
-      for k, v in loaded_state_dict.items():
-        tensors[k] = torch2jax(v)
+    # Support sharded loading
+    tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)

    flax_state_dict = {}
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_dict = flatten_dict(eval_shapes)

-    # Create random state dict with string keys for matching
    random_flax_state_dict = {}
    for key in flattened_dict:
-      # Convert all ints to strings in key tuple
      string_tuple = tuple([str(item) for item in key])
      random_flax_state_dict[string_tuple] = flattened_dict[key]

    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
      renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)

-      # Handling specific replacements that `rename_key` might miss or `rename_for_ltx2` specifically targets
-      # The `scan_layers` handling requires splitting the key differently if needed.
-
      pt_tuple_key = tuple(renamed_pt_key.split("."))

      flax_key, flax_tensor = get_key_and_value(
@@ -176,7 +137,6 @@ def load_transformer_weights(
  jax.clear_caches()
  return flax_state_dict

-
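Both loaders match checkpoint keys against the Flax parameter tree by flattening eval_shapes and stringifying every element of each key tuple. A self-contained sketch of that matching step (toy pytree; flatten_dict is assumed to behave like flax.traverse_util.flatten_dict):

    from flax.traverse_util import flatten_dict
    import jax.numpy as jnp

    eval_shapes = {"transformer_blocks": {"layers": {0: {"proj_in": {"kernel": jnp.zeros((4, 4))}}}}}
    flattened = flatten_dict(eval_shapes)
    random_flax_state_dict = {tuple(str(k) for k in key): value for key, value in flattened.items()}
    # Keys are now tuples of strings, e.g. ("transformer_blocks", "layers", "0", "proj_in", "kernel"),
    # which is the form rename_key_and_reshape_tensor matches renamed PyTorch keys against.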
def load_vae_weights(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
@@ -185,22 +145,16 @@ def load_vae_weights(
    subfolder: str = "vae"
):
  device = jax.local_devices(backend=device)[0]
-  filename = "diffusion_pytorch_model.safetensors"
+  # The lightricks/LTX-2 VAE ships as a single (non-sharded) file, so a plain
+  # download with a .bin fallback is enough here; load_sharded_checkpoint could
+  # be reused instead if general robustness were needed.

-  if os.path.isdir(pretrained_model_name_or_path):
-    ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-    if not os.path.isfile(ckpt_path):
-      filename = "diffusion_pytorch_model.bin"
-      ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-      if not os.path.isfile(ckpt_path):
-        raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
-  elif hf_download:
-    try:
+  filename = "diffusion_pytorch_model.safetensors"
+  try:
    ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-  except Exception:
+  except Exception:
    filename = "diffusion_pytorch_model.bin"
    ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-
+
  max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}")

  with jax.default_device(device):
@@ -218,83 +172,32 @@ def load_vae_weights(
    cpu = jax.local_devices(backend="cpu")[0]
    flattened_eval = flatten_dict(eval_shapes)

-    # Build random state dict for shape checking/key matching help
-    # VAE usually doesn't need scan layers logic for mapping (unless we implement scanned VAE similar to Transformer, but autoencoder_kl_ltx2.py uses scan but keys seem compatible with standard diffusers structure if mapped correctly)
-    # Wait, `autoencoder_kl_ltx2.py` DOES use scan for `resnets`!
-    # See `create_resnets` and `resnet_scan_fn`.
-    # So we DO need scan layer handling for VAE if we want to load it into that structure.
-    # The VAE resnets are scanned over `num_layers`.
-
-    # Mapping Diffusers VAE to Scanned VAE:
-    # Diffusers: down_blocks.0.resnets.0 ...
-    # Flax Scanned: down_blocks.0.resnets.layers.0 ... (if mapped that way)
-    # OR: down_blocks.0.resnets -> (num_layers, ...) tensor if we stack them.
-
-    # Let's check `autoencoder_kl_ltx2.py` again.
-    # `self.resnets = create_resnets(rngs)` where `create_resnets` is vmapped.
-    # This creates params with a leading dimension = num_layers.
-    # So we need to stack Diffusers resnets weights.
-
-    # We need a custom `get_key_and_value` for VAE or modify the existing one to handle VAE blocks too.
-    pass
-
-    # For now, let's just write the loading logic and we might need to iterate and fix VAE scanning logic if it fails validation.
-    # Ideally we use `rename_key_and_reshape_tensor` heavily.
-
    random_flax_state_dict = {}
    for key in flattened_eval:
      string_tuple = tuple([str(item) for item in key])
      random_flax_state_dict[string_tuple] = flattened_eval[key]
-
+
    for pt_key, tensor in tensors.items():
      renamed_pt_key = rename_key(pt_key)
-
-      # VAE specific renames
-      renamed_pt_key = renamed_pt_key.replace("mid_block.resnets.", "mid_block.resnets.layers.")
-      renamed_pt_key = renamed_pt_key.replace("down_blocks.", "down_blocks.")  # keeping same
-      # Need to handle resnets.0 -> resnets.layers.0 etc if we want to be explicit, or rely on scanning logic.
-
-      # If we use scan, we need to stack "resnets.0", "resnets.1" etc into "resnets" tensor.
-      # The logic in `get_key_and_value` handles `transformer_blocks` scanning. We should extend it for VAE `resnets`.
-
-      # Actually, `autoencoder_kl_ltx2.py` VAE scanning is slightly different.
-      # It scans over `resnets`.
-      # Diffusers has `down_blocks.0.resnets.0`, `down_blocks.0.resnets.1`.
-      # We need to stack these.
-
+      if ".resnets." in renamed_pt_key:
+        # Fold the resnet index into the name, e.g. "down_blocks.0.resnets.1.<rest>" -> "down_blocks.0.resnets_1.<rest>"
+        parts = renamed_pt_key.split(".")
+        new_parts = []
+        i = 0
+        while i < len(parts):
+          if parts[i] == "resnets" and i + 1 < len(parts) and parts[i + 1].isdigit():
+            new_parts.append(f"resnets_{parts[i + 1]}")
+            i += 2
+          else:
+            new_parts.append(parts[i])
+            i += 1
+        renamed_pt_key = ".".join(new_parts)
+
      pt_tuple_key = tuple(renamed_pt_key.split("."))

-      # Let's add VAE scanning logic here or in a helper
-      # Identifying keys to stack: keys containing `resnets._`
-
-      # Simplified VAE Loading (non-scanned or manual stacking):
-      # If `rename_key_and_reshape_tensor` expects exact matching, we might have trouble if keys are "resnets.0" but flax expects "resnets" (stacked).
-
-      # I will implement a check: if key has `resnets.N`, we try to stack it.
-
      flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
-
-      # If it didn't match immediately, check if it's a resnet layer that needs stacking
-      # This part is tricky without strictly knowing num_layers per block.
-      # But we can infer or just load individually if Flax wasn't scanned?
-      # The Flax code definitely uses scan.
-
-      # HACK: For VAE, let's assume we might need to manually stack or map to specific indices if `rename_key_and_reshape_tensor` didn't catch it.
-      # But for now, let's just use `rename_key_and_reshape_tensor` and `validate_flax_state_dict` will tell us what failed.
-
      flax_key = _tuple_str_to_int(flax_key)

-      # Manual VAE Stacking logic if needed:
-      # if "resnets" in flax_key and generic match failed...
-
-      # Let's rely on `validate_flax_state_dict` to debug VAE mapping in the test phase if it's complex.
-      # But I should probably add the `resnets` -> `resnets.layers` replacement to be safe?
-      # Wait, if I replace `resnets.0` with `resnets.layers.0`, and Flax expects `resnets` (stacked), it still won't match.
-      # Flax `nnx.vmap` with `transform_metadata={nnx.PARTITION_NAME: "layers"}` usually expects a stacked axis.
-      # The parameter key in `state_dict` for a vmapped layer often depends on how it's stored.
-      # In NNX/Flax, it might be stored as `resnets.layers`? No, usually just `resnets` with an extra dim?
-      # Or `resnets.layers.kernel` if it kept the name.
-
      flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

  validate_flax_state_dict(eval_shapes, flax_state_dict)