
Commit 7847b19

committed
fix
1 parent 758d8a4 commit 7847b19

1 file changed

Lines changed: 58 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

@@ -160,6 +160,64 @@ def replace_suffix(lst, old, new):
       flax_key_str[-2] = "to_out"
 
     flax_key = tuple(flax_key_str)
+
+    # Explicit fixes for LTX-2
+    # 1. Linear projection weights: weight -> kernel
+    # Keys: audio_caption_projection, caption_projection, audio_proj_in, proj_in,
+    # audio_proj_out, proj_out, time_embed, audio_time_embed.
+    # Also the av_cross_attn layers.
+    if flax_key[-1] == "weight":
+      # Check whether we should rename to kernel.
+      # Heuristic: the parent contains 'linear', 'proj', or 'time_embed', or the key matches the known list.
+      parent = flax_key[-2] if len(flax_key) > 1 else ""
+      grandparent = flax_key[-3] if len(flax_key) > 2 else ""
+
+      should_be_kernel = False
+      if "linear" in parent or "proj" in parent or "proj" in grandparent:
+        should_be_kernel = True
+      if "time_embed" in flax_key[0] or "cross_attn" in flax_key[0]:
+        if "linear" in parent or "emb" in parent:  # e.g. time_embed.linear
+          should_be_kernel = True
+
+      # Exception: norm weights are scale, handled below.
+      if "norm" in parent:
+        should_be_kernel = False
+
+      if should_be_kernel:
+        flax_key = flax_key[:-1] + ("kernel",)
+
+    # 2. Norm weights: weight -> scale
+    # Keys: norm_k, norm_q, norm_out, audio_norm_out
+    if flax_key[-1] == "weight":
+      if "norm" in flax_key[-2] or "norm" in flax_key[0]:
+        flax_key = flax_key[:-1] + ("scale",)
+
+    # 3. Audio/video attention specifics
+    # Checkpoint: attn1.to_q.weight -> Flax: attn1.to_q.kernel
+    # rename_key usually handles this when it sees 'Linear', but it can miss here
+    # if the valid_prefixes check fails or the key is just 'to_q'.
+    # Force the q/k/v/out projections to kernel.
+    if flax_key[-1] == "weight" and flax_key[-2] in ["to_q", "to_k", "to_v", "to_out"]:
+      flax_key = flax_key[:-1] + ("kernel",)
+
+    # 4. Fix 'to_out.0' -> 'to_out' if the index persisted (rename_key sometimes
+    # emits to_out_0 instead). The checkpoint stores to_out.0.weight and
+    # to_out.0.bias, which implies to_out was an nn.Sequential or module list in
+    # PyTorch, while LTX2Attention defines self.to_out = nnx.Linear(...), so the
+    # Flax side expects ('to_out', 'kernel'). Drop the '0' when it follows 'to_out'.
+    if len(flax_key) >= 3 and flax_key[-2] == "0" and flax_key[-3] == "to_out":
+      # ('to_out', '0', 'kernel') -> ('to_out', 'kernel')
+      flax_key = flax_key[:-3] + ("to_out", flax_key[-1])
+
+    # 5. norm_k/norm_q weights -> scale: already covered by rule 2
+    # (attn1.norm_k.weight has parent 'norm_k', which contains 'norm').
+
+    flax_key_str = [str(k) for k in flax_key]  # Update the str list for the final check.
     flax_key = _tuple_str_to_int(flax_key)
 
     if scan_layers and block_index is not None:
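
For reference, here is a minimal, self-contained sketch of the renaming rules this commit adds, with example key mappings. `apply_ltx2_key_fixes` is a hypothetical name for illustration; in the repo the logic runs inline after rename_key has produced the key tuple, and the example keys below are illustrative rather than taken from a real LTX-2 checkpoint.

def apply_ltx2_key_fixes(flax_key: tuple) -> tuple:
  # Rule 1: linear/projection weights -> kernel (norm weights excluded).
  if flax_key[-1] == "weight":
    parent = flax_key[-2] if len(flax_key) > 1 else ""
    grandparent = flax_key[-3] if len(flax_key) > 2 else ""
    should_be_kernel = "linear" in parent or "proj" in parent or "proj" in grandparent
    if ("time_embed" in flax_key[0] or "cross_attn" in flax_key[0]) and (
        "linear" in parent or "emb" in parent):
      should_be_kernel = True
    if "norm" in parent:
      should_be_kernel = False
    if should_be_kernel:
      flax_key = flax_key[:-1] + ("kernel",)

  # Rule 2: norm weights -> scale.
  if flax_key[-1] == "weight" and ("norm" in flax_key[-2] or "norm" in flax_key[0]):
    flax_key = flax_key[:-1] + ("scale",)

  # Rule 3: attention q/k/v/out projection weights -> kernel.
  if flax_key[-1] == "weight" and flax_key[-2] in ("to_q", "to_k", "to_v", "to_out"):
    flax_key = flax_key[:-1] + ("kernel",)

  # Rule 4: drop the container index in ('to_out', '0', ...).
  if len(flax_key) >= 3 and flax_key[-2] == "0" and flax_key[-3] == "to_out":
    flax_key = flax_key[:-3] + ("to_out", flax_key[-1])
  return flax_key

assert apply_ltx2_key_fixes(("caption_projection", "linear_1", "weight")) == (
    "caption_projection", "linear_1", "kernel")
assert apply_ltx2_key_fixes(("attn1", "norm_k", "weight")) == ("attn1", "norm_k", "scale")
assert apply_ltx2_key_fixes(("attn1", "to_q", "weight")) == ("attn1", "to_q", "kernel")
assert apply_ltx2_key_fixes(("attn1", "to_out", "0", "kernel")) == ("attn1", "to_out", "kernel")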
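
On rule 4 specifically: the to_out.0.* names arise because PyTorch container modules prefix child parameters with their list index. A short illustration, assuming torch is available and that the original PyTorch attention wrapped its output projection in an nn.Sequential with a dropout (an assumption about the upstream module, not confirmed by this repo):

import torch.nn as nn

# Hypothetical stand-in for the original PyTorch output projection.
to_out = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.0))

# Sequential prefixes each child's parameters with its index; Dropout has no
# parameters, so the checkpoint ends up with 'to_out.0.weight' and 'to_out.0.bias'.
print(list(to_out.state_dict().keys()))  # ['0.weight', '0.bias']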
