@@ -30,54 +30,17 @@ def rename_for_ltx2_transformer(key):
     key = key.replace("patchify_proj", "proj_in")
     key = key.replace("audio_patchify_proj", "audio_proj_in")
     key = key.replace("norm_final", "norm_out")
-
-    # Handle scale_shift_table
-    # PyTorch: adaLN_modulation.1.weight/bias -> scale_shift_table
-    # rename_key changes adaLN_modulation.1 -> adaLN_modulation_1
     if "adaLN_modulation_1" in key:
         key = key.replace("adaLN_modulation_1", "scale_shift_table")
 
     if "caption_modulator_1" in key:
         key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
-
-    # Audio caption modulator?
-    # Checkpoint: audio_caption_modulator.1.weight (Guessing name)
-    # Let's inspect checkpoint keys for clues if this guess fails.
     if "audio_caption_modulator_1" in key:
         key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
-
-    # Handle audio_caption_projection
-    # Checkpoint: audio_caption_projection.linear_1.weight
-    # Flax: audio_caption_projection.linear_1.kernel
-    # rename_key_and_reshape_tensor catches 'weight' -> 'kernel', but maybe something else renaming it?
-    # No explicit rename needed if it's already linear_1/linear_2 unless name mismatch.
-
-    # Handle global norms (norm_out, audio_norm_out)
-    # Checkpoint: norm_final -> norm_out (already handled)
-    # Checkpoint also has audio_norm_final -> audio_norm_out?
     if "audio_norm_final" in key:
         key = key.replace("audio_norm_final", "audio_norm_out")
-
-    # Handle time_embed/audio_time_embed
-    # Checkpoint: time_embed.emb.timestep_embedder.linear_1.weight
-    # Flax: time_embed.emb.timestep_embedder.linear_1.kernel
-    # If checkpoint uses different name structure?
-    # time_embed.emb.timestep_embedder -> time_embed.emb.timestep_embedder (seems OK)
-
-    # Handle av_cross_attn...
-    # These seem fine in name but verify if they are Linear or Conv? Linear.
-
-
-
-    # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
-
-    # Handle audio_ff.net_0.proj -> audio_ff.net_0
-    # Also handle ff.net_0.proj -> ff.net_0
     if ("audio_ff" in key or "ff" in key) and "proj" in key:
         key = key.replace(".proj", "")
-
-    # Handle to_out.0 -> to_out for LTX2Attention
-    # rename_key changes to_out.0 -> to_out_0
     if "to_out_0" in key:
         key = key.replace("to_out_0", "to_out")
 
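For orientation, a minimal standalone sketch of how these string replacements compose (an illustrative copy, not the converter itself; the sample key assumes an upstream rename_key pass has already turned adaLN_modulation.1 into adaLN_modulation_1):

    def rename_for_ltx2_transformer_sketch(key: str) -> str:
        # Mirrors a few of the replacements in the hunk above.
        key = key.replace("patchify_proj", "proj_in")
        key = key.replace("norm_final", "norm_out")
        key = key.replace("adaLN_modulation_1", "scale_shift_table")
        key = key.replace("to_out_0", "to_out")
        return key

    # Illustrative checkpoint key -> Flax-style module path:
    print(rename_for_ltx2_transformer_sketch("transformer_blocks.0.adaLN_modulation_1.weight"))
    # transformer_blocks.0.scale_shift_table.weight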
@@ -110,39 +73,24 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
         pass
 
     flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
-
-    # Check if we got 'kernel' but expected 'scale' (common for scanned layers where shape check fails)
-    # Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
     flax_key_str = [str(k) for k in flax_key]
 
     if flax_key_str[-1] in ["kernel", "weight"]:
-        # Try replacing with scale and check if it exists in random_flax_state_dict
         temp_key_str = flax_key_str[:-1] + ["scale"]
         temp_key = tuple(temp_key_str)  # Tuple of strings
 
         if temp_key in random_flax_state_dict:
             flax_key_str = temp_key_str
             pass
-
-    # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
-    # Fix scale_shift_table mapping if it got 'kernel' appended
     if "scale_shift_table" in flax_key_str:
-        # if last is kernel/weight, remove it
         if flax_key_str[-1] in ["kernel", "weight"]:
             flax_key_str.pop()
 
-    # Handle audio_norm_out / norm_out bias mapping
-    # If renamed to ('audio_norm_out', 'bias') matches ('audio_norm_out', 'bias') in random_flax_state_dict?
-    # Yes. But if rename_key mapped it differently?
-    # Ensure norm_out/audio_norm_out are preserved.
-
-    # Helper to replace last occurrence
     def replace_suffix(lst, old, new):
         if lst and lst[-1] == old:
             lst[-1] = new
         return lst
 
-    # LTX-2 uses to_q, to_k, to_v, to_out, NOT query, key, value, proj_attn
     if "transformer_blocks" in flax_key_str:
         if flax_key_str[-1] == "query":
             flax_key_str[-1] = "to_q"
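Background for the weight/kernel/scale heuristics in these hunks: PyTorch nn.Linear stores weight as (out_features, in_features), a Flax Linear stores kernel as (in_features, out_features), and norm weights keep their shape but take the leaf name scale. A small sketch with illustrative shapes (the transpose itself is presumably done inside rename_key_and_reshape_tensor; the heuristics here only fix up the trailing name):

    import numpy as np

    pt_linear_weight = np.zeros((128, 64))   # PyTorch Linear: (out_features, in_features)
    flax_kernel = pt_linear_weight.T         # Flax Linear kernel: (in_features, out_features)
    assert flax_kernel.shape == (64, 128)

    pt_norm_weight = np.ones((128,))         # LayerNorm/RMSNorm weight
    flax_scale = pt_norm_weight              # same shape, only the leaf name becomes "scale"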
@@ -157,71 +105,41 @@ def replace_suffix(lst, old, new):
 
     flax_key = tuple(flax_key_str)
 
-    # Explicit fixes for LTX-2
-    # 1. Linear Projection Weights: weight -> kernel
-    # Keys: audio_caption_projection, caption_projection, audio_proj_in, proj_in, audio_proj_out, proj_out, time_embed, audio_time_embed
-    # Also av_cross_attn lines.
     if flax_key[-1] == "weight":
-        # Check if we should rename to kernel
-        # Heuristic: if parent has 'linear' or 'proj' or 'time_embed' or matches known list
         parent = flax_key[-2] if len(flax_key) > 1 else ""
         grandparent = flax_key[-3] if len(flax_key) > 2 else ""
 
         should_be_kernel = False
         if "linear" in parent or "proj" in parent or "proj" in grandparent:
             should_be_kernel = True
         if "time_embed" in flax_key[0] or "cross_attn" in flax_key[0]:
-            if "linear" in parent or "emb" in parent:  # time_embed.linear
+            if "linear" in parent or "emb" in parent:
                 should_be_kernel = True
 
-        # Exception: norm weights are scale, handled below
         if "norm" in parent:
             should_be_kernel = False
 
         if should_be_kernel:
             flax_key = flax_key[:-1] + ("kernel",)
 
-    # 2. Norm Weights: weight -> scale
-    # Keys: norm_k, norm_q, norm_out, audio_norm_out
     if flax_key[-1] == "weight":
         if "norm" in flax_key[-2] or "norm" in flax_key[0]:
             flax_key = flax_key[:-1] + ("scale",)
 
-    # 3. Audio/Video Attention specifics
-    # Checkpoint: attn1.to_q.weight -> Flax: attn1.to_q.kernel
-    # rename_key usually handles this if it sees 'Linear', but here it might miss if valid_prefixes check fails or it's just 'to_q'.
-    # Force q/k/v/out projections to kernel
     if flax_key[-1] == "weight" and flax_key[-2] in ["to_q", "to_k", "to_v", "to_out"]:
         flax_key = flax_key[:-1] + ("kernel",)
 
-    # 4. Fix 'to_out.0' -> 'to_out' if it persisted (sometimes rename_key does to_out_0)
-    # Actually my previous fix in rename_for_ltx2_transformer handles string replacement?
-    # Let's double check tuple.
-    # If we have ('to_out', '0', 'weight') -> ('to_out', 'kernel') ?
-    # Flax expects just 'to_out'? No, Flax nnx.Linear is typically a leaf unless wrapped.
-    # LTX2Attention defines: self.to_out = nnx.Linear(...)
-    # So it expects ('to_out', 'kernel').
-    # Checkpoint has: to_out.0.weight and to_out.0.bias
-    # This implies it was a Sequential or List in PyTorch?
-    # If we see '0' in key, drop it if it's 'to_out'.
     if len(flax_key) >= 2 and flax_key[-2] == "0" and flax_key[-3] == "to_out":
-        # ('to_out', '0', 'kernel') -> ('to_out', 'kernel')
         flax_key = flax_key[:-3] + ("to_out", flax_key[-1])
 
-    # 5. Fix norm_k/q becoming scale
-    # Already handled by rule 2?
-    # Checkpoint: attn1.norm_k.weight
-    # Rule 2: parent 'norm_k' has 'norm' -> becomes scale. Correct.
-
-    flax_key_str = [str(k) for k in flax_key]  # Update str list for final check
+    flax_key_str = [str(k) for k in flax_key]
     flax_key = _tuple_str_to_int(flax_key)
 
     if scan_layers and block_index is not None:
         if "transformer_blocks" in flax_key:
             if flax_key in flax_state_dict:
                 new_tensor = flax_state_dict[flax_key]
             else:
-                # Initialize with correct shape (layers, ...)
                 new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype)
 
             new_tensor = new_tensor.at[block_index].set(flax_tensor)
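The scan_layers branch above stacks each block's tensor into a single array with a leading layer axis; a minimal sketch of that functional-update pattern (layer count and shape are illustrative):

    import jax.numpy as jnp

    num_layers = 4
    per_block = jnp.ones((8, 16))                # one converted block tensor
    stacked = jnp.zeros((num_layers,) + per_block.shape, dtype=per_block.dtype)
    for block_index in range(num_layers):
        # JAX arrays are immutable, so indexed writes go through .at[...].set(...)
        stacked = stacked.at[block_index].set(per_block)
    assert stacked.shape == (num_layers, 8, 16)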
@@ -235,8 +153,6 @@ def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
235153 """
236154 index_file = "diffusion_pytorch_model.safetensors.index.json"
237155 tensors = {}
238-
239- # Try to download index file
240156 try :
241157 index_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = index_file )
242158 with open (index_path , "r" ) as f :
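For reference, a hedged sketch of consuming that index file: the usual safetensors sharding layout maps each parameter name to its shard under a weight_map entry, so loading means fetching every referenced shard once (repo id and subfolder below are placeholders):

    import json
    from huggingface_hub import hf_hub_download
    from safetensors.numpy import load_file

    repo_id, subfolder = "some-org/some-model", "transformer"   # placeholders
    index_path = hf_hub_download(repo_id, subfolder=subfolder,
                                 filename="diffusion_pytorch_model.safetensors.index.json")
    with open(index_path, "r") as f:
        weight_map = json.load(f)["weight_map"]                  # {param_name: shard_filename}

    tensors = {}
    for shard in sorted(set(weight_map.values())):               # fetch each shard once
        shard_path = hf_hub_download(repo_id, subfolder=subfolder, filename=shard)
        tensors.update(load_file(shard_path))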
@@ -324,8 +240,6 @@ def load_vae_weights(
     subfolder: str = "vae"
 ):
     device = jax.local_devices(backend=device)[0]
-    # VAE for LTX-2 is likely single file, but safe to use the helper if we wanted general robustness.
-    # But `lightricks/LTX-2` VAE is single file.
 
     filename = "diffusion_pytorch_model.safetensors"
     try:
@@ -365,7 +279,6 @@ def load_vae_weights(
         resnet_index = None
 
         for i, part in enumerate(pt_tuple_key):
-            # Check for name_N pattern
             if "_" in part and part.split("_")[-1].isdigit():
                 name = "_".join(part.split("_")[:-1])
                 idx = int(part.split("_")[-1])
@@ -375,7 +288,6 @@ def load_vae_weights(
                     pt_list.append(str(idx))
                 elif name == "upsamplers":
                     pt_list.append("upsampler")
-                    # Skip the index 0 for upsampler as Flax uses singular non-list
                 elif name in ["down_blocks", "up_blocks", "downsamplers"]:
                     pt_list.append(name)
                     pt_list.append(str(idx))
@@ -385,56 +297,42 @@ def load_vae_weights(
                 pt_list.append("upsampler")
             elif part in ["conv1", "conv2", "conv"]:
                 pt_list.append(part)
-                # Inject 'conv' if it's not already there AND not just added
                 if i + 1 < len(pt_tuple_key) and pt_tuple_key[i + 1] == "conv":
-                    pass  # already has conv
+                    pass
                 elif pt_list[-1] == "conv":
-                    pass  # already has conv
+                    pass
                 elif len(pt_list) >= 2 and pt_list[-2] == "conv":
                     pass
                 elif part == "conv":
                     pass
                 else:
-                    # If part is conv1/conv2, append 'conv'
                     pt_list.append("conv")
             else:
                 pt_list.append(part)
 
         pt_tuple_key = tuple(pt_list)
 
         flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
-        # _tuple_str_to_int might not be needed if we already injected ints, but it's safe
         flax_key = _tuple_str_to_int(flax_key)
 
-        # Allow latents_mean/std
-
-        # DEBUG
         flax_key_str = [str(x) for x in flax_key]
         if "conv" in flax_key_str or "bias" in flax_key_str:
-            # print(f"DEBUG: VAE Key Map: {pt_tuple_key} -> {flax_key}")
             pass
 
         if resnet_index is not None:
             if flax_key in flax_state_dict:
                 current_tensor = flax_state_dict[flax_key]
             else:
-                # Initialize with correct shape from random_flax_state_dict
-                # We must use STRING tuple for lookup in random_flax_state_dict
                 str_flax_key = tuple([str(x) for x in flax_key])
 
                 if str_flax_key in random_flax_state_dict:
                     target_shape = random_flax_state_dict[str_flax_key].shape
                     current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
                 else:
-                    # Fallback if key missing (shouldn't happen with correct mapping)
-                    # print(f"Warning: Key {str_flax_key} not found in random_flax_state_dict, cannot stack.")
-                    current_tensor = flax_tensor  # Might fail shape check later
-
-            # Place the tensor at the correct index
-            # flax_tensor is (..., C), target is (N_resnets, ..., C)
+                    current_tensor = flax_tensor
 
             str_flax_key = tuple([str(x) for x in flax_key])
-            if str_flax_key in random_flax_state_dict:  # Only stack if we have a valid target
+            if str_flax_key in random_flax_state_dict:
                 current_tensor = current_tensor.at[resnet_index].set(flax_tensor)
                 flax_state_dict[flax_key] = current_tensor
         else:
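A small sketch of the name_N parsing used above to recover list indices from flattened PyTorch module names (the sample keys are illustrative):

    def split_indexed_name(part: str):
        # "resnets_2" -> ("resnets", 2); parts without a trailing integer pass through unchanged.
        if "_" in part and part.split("_")[-1].isdigit():
            return "_".join(part.split("_")[:-1]), int(part.split("_")[-1])
        return part, None

    print(split_indexed_name("resnets_2"))   # ('resnets', 2)
    print(split_indexed_name("conv_out"))    # ('conv_out', None)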