Skip to content

Commit 4c4446c

Browse files
committed
fix
1 parent 2be0be8 commit 4c4446c

3 files changed

Lines changed: 40 additions & 24 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ hardware: 'tpu'
33
skip_jax_distributed_system: False
44
attention: 'flash'
55
attention_sharding_uniform: True
6+
audio_attention_head_dim: 128
67

78
jax_cache_dir: ''
89
weights_dtype: 'bfloat16'

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def rename_for_ltx2_transformer(key):
4040
# This line was redundant, keeping it as a no-op or removing it is fine.
4141
# The instruction implies it should be `return key` at the end.
4242
key = key.replace("transformer_blocks", "transformer_blocks")
43+
44+
# Handle to_out.0 -> to_out for LTX2Attention
45+
if "to_out.0" in key:
46+
key = key.replace("to_out.0", "to_out")
47+
4348
return key
4449

4550

@@ -145,9 +150,16 @@ def load_transformer_weights(
145150
for k in list(tensors.keys())[:20]:
146151
print(k)
147152

153+
148154
print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
149155
for k in list(random_flax_state_dict.keys())[:20]:
150156
print(k)
157+
158+
print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
159+
for k in list(random_flax_state_dict.keys()):
160+
if "transformer_blocks" in k and "attn1" in k:
161+
print(k)
162+
break
151163

152164
for pt_key, tensor in tensors.items():
153165
renamed_pt_key = rename_key(pt_key)
@@ -211,17 +223,10 @@ def load_vae_weights(
211223
renamed_pt_key = rename_key(pt_key)
212224
if ".resnets." in renamed_pt_key:
213225
# pattern: resnets.0 -> resnets_0
214-
parts = renamed_pt_key.split(".")
215-
new_parts = []
216-
i = 0
217-
while i < len(parts):
218-
if parts[i] == "resnets" and i+1 < len(parts) and parts[i+1].isdigit():
219-
new_parts.append(f"resnets_{parts[i+1]}")
220-
i += 2
221-
else:
222-
new_parts.append(parts[i])
223-
i += 1
224-
renamed_pt_key = ".".join(new_parts)
226+
# We need to capture the number after resnets
227+
import re
228+
# Replace resnets.N with resnets_N
229+
renamed_pt_key = re.sub(r"resnets\.(\d+)", r"resnets_\1", renamed_pt_key)
225230

226231
pt_tuple_key = tuple(renamed_pt_key.split("."))
227232

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,29 @@ def test_load_transformer_weights(self):
2929
pretrained_model_name_or_path = "Lightricks/LTX-2"
3030

3131
with jax.default_device(jax.devices("cpu")[0]):
32-
model = LTX2VideoTransformer3DModel(
33-
rngs=self.rngs,
34-
# Explicitly setting key params to version 2.0 to be safe
35-
in_channels=128,
36-
out_channels=128,
37-
patch_size=1,
38-
patch_size_t=1,
39-
num_attention_heads=32,
40-
attention_head_dim=128,
41-
cross_attention_dim=4096,
42-
num_layers=48,
43-
scan_layers=True
44-
)
32+
self.config = LTX2VideoConfig()
33+
self.config.audio_attention_head_dim = 128 # Match Checkpoint
34+
35+
self.transformer = LTX2VideoTransformer3DModel(
36+
in_channels=self.config.in_channels,
37+
out_channels=self.config.out_channels,
38+
patch_size=self.config.patch_size,
39+
patch_size_t=self.config.patch_size_t,
40+
num_attention_heads=self.config.num_attention_heads,
41+
attention_head_dim=self.config.attention_head_dim,
42+
cross_attention_dim=self.config.cross_attention_dim,
43+
audio_in_channels=self.config.audio_in_channels,
44+
audio_out_channels=self.config.audio_out_channels,
45+
audio_patch_size=self.config.audio_patch_size,
46+
audio_patch_size_t=self.config.audio_patch_size_t,
47+
audio_num_attention_heads=self.config.audio_num_attention_heads,
48+
audio_attention_head_dim=128, # Match Config/Checkpoint
49+
audio_cross_attention_dim=self.config.audio_cross_attention_dim,
50+
num_layers=self.config.num_layers,
51+
scan_layers=True,
52+
param_dtype=jnp.bfloat16,
53+
rngs=nnx.Rngs(0),
54+
)
4555

4656
# Get abstract state (shapes only)
4757
# We need the PyTree structure of parameters

0 commit comments

Comments (0)