@@ -160,6 +160,64 @@ def replace_suffix(lst, old, new):
160160 flax_key_str [- 2 ] = "to_out"
161161
162162 flax_key = tuple (flax_key_str )
163+
164+ # Explicit fixes for LTX-2
165+ # 1. Linear Projection Weights: weight -> kernel
166+ # Keys: audio_caption_projection, caption_projection, audio_proj_in, proj_in, audio_proj_out, proj_out, time_embed, audio_time_embed
167+ # Also av_cross_attn lines.
168+ if flax_key [- 1 ] == "weight" :
169+ # Check if we should rename to kernel
170+ # Heuristic: if parent has 'linear' or 'proj' or 'time_embed' or matches known list
171+ parent = flax_key [- 2 ] if len (flax_key ) > 1 else ""
172+ grandparent = flax_key [- 3 ] if len (flax_key ) > 2 else ""
173+
174+ should_be_kernel = False
175+ if "linear" in parent or "proj" in parent or "proj" in grandparent :
176+ should_be_kernel = True
177+ if "time_embed" in flax_key [0 ] or "cross_attn" in flax_key [0 ]:
178+ if "linear" in parent or "emb" in parent : # time_embed.linear
179+ should_be_kernel = True
180+
181+ # Exception: norm weights are scale, handled below
182+ if "norm" in parent :
183+ should_be_kernel = False
184+
185+ if should_be_kernel :
186+ flax_key = flax_key [:- 1 ] + ("kernel" ,)
187+
188+ # 2. Norm Weights: weight -> scale
189+ # Keys: norm_k, norm_q, norm_out, audio_norm_out
190+ if flax_key [- 1 ] == "weight" :
191+ if "norm" in flax_key [- 2 ] or "norm" in flax_key [0 ]:
192+ flax_key = flax_key [:- 1 ] + ("scale" ,)
193+
194+ # 3. Audio/Video Attention specifics
195+ # Checkpoint: attn1.to_q.weight -> Flax: attn1.to_q.kernel
196+ # rename_key usually handles this if it sees 'Linear', but here it might miss if valid_prefixes check fails or it's just 'to_q'.
197+ # Force q/k/v/out projections to kernel
198+ if flax_key [- 1 ] == "weight" and flax_key [- 2 ] in ["to_q" , "to_k" , "to_v" , "to_out" ]:
199+ flax_key = flax_key [:- 1 ] + ("kernel" ,)
200+
201+ # 4. Fix 'to_out.0' -> 'to_out' if it persisted (sometimes rename_key does to_out_0)
202+ # Actually my previous fix in rename_for_ltx2_transformer handles string replacement?
203+ # Let's double check tuple.
204+ # If we have ('to_out', '0', 'weight') -> ('to_out', 'kernel') ?
205+ # Flax expects just 'to_out'? No, Flax nnx.Linear is typically a leaf unless wrapped.
206+ # LTX2Attention defines: self.to_out = nnx.Linear(...)
207+ # So it expects ('to_out', 'kernel').
208+ # Checkpoint has: to_out.0.weight and to_out.0.bias
209+ # This implies it was a Sequential or List in PyTorch?
210+ # If we see '0' in key, drop it if it's 'to_out'.
211+ if len (flax_key ) >= 2 and flax_key [- 2 ] == "0" and flax_key [- 3 ] == "to_out" :
212+ # ('to_out', '0', 'kernel') -> ('to_out', 'kernel')
213+ flax_key = flax_key [:- 3 ] + ("to_out" , flax_key [- 1 ])
214+
215+ # 5. Fix norm_k/q becoming scale
216+ # Already handled by rule 2?
217+ # Checkpoint: attn1.norm_k.weight
218+ # Rule 2: parent 'norm_k' has 'norm' -> becomes scale. Correct.
219+
220+ flax_key_str = [str (k ) for k in flax_key ] # Update str list for final check
163221 flax_key = _tuple_str_to_int (flax_key )
164222
165223 if scan_layers and block_index is not None :
0 commit comments