Skip to content

Commit 6b2726c

Browse files
committed
fix
1 parent 303c4f2 commit 6b2726c

3 files changed

Lines changed: 52 additions & 4 deletions

File tree

debug_eval_shapes.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
import jax
import jax.numpy as jnp
from flax import nnx
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel


def debug_eval_shapes():
  """Instantiate a tiny LTX2 transformer and print selected parameter paths.

  Builds a deliberately small model (few heads / few layers, so construction
  is fast), flattens its nnx state tree, and prints every flat key that
  matches a handful of substrings of interest: output norms, the audio
  caption projection, scale/shift tables, and the ``norm_k`` parameters of
  the audio-to-video attention inside the transformer blocks.
  """
  model = LTX2VideoTransformer3DModel(
      rngs=nnx.Rngs(0),
      in_channels=128,
      out_channels=128,
      patch_size=1,
      patch_size_t=1,
      num_attention_heads=4,  # Small for speed
      attention_head_dim=32,
      cross_attention_dim=64,
      audio_dim=32,
      audio_num_attention_heads=4,
      audio_attention_head_dim=8,
      audio_cross_attention_dim=32,
      num_layers=2,  # Small for speed
      scan_layers=True,
  )

  # Flatten the pure state dict so each parameter is addressed by a flat
  # tuple path, which makes substring probing straightforward.
  from flax.traverse_util import flatten_dict

  flattened = flatten_dict(nnx.state(model).to_pure_dict())

  print("--- EVAL SHAPES DEBUG ---")
  for key in sorted(flattened.keys()):
    key_text = str(key)
    if "norm_out" in key_text:
      print(f"NORM_OUT: {key}")
    if "audio_caption_projection" in key_text:
      print(f"AUDIO_CAP_PROJ: {key}")
    if "scale_shift_table" in key_text:
      print(f"SCALE_SHIFT: {key}")
    # Narrow probe: audio->video attention key-norm params inside the blocks.
    if "transformer_blocks" in key_text and "audio_to_video_attn" in key_text and "norm_k" in key_text:
      print(f"BLOCK_KEY: {key}")


if __name__ == "__main__":
  debug_eval_shapes()

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ def init_block(rngs):
807807
# 6. Output layers
808808
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
809809
self.norm_out = nnx.LayerNorm(
810-
inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
810+
inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
811811
)
812812
self.proj_out = nnx.Linear(
813813
inner_dim,
@@ -820,7 +820,7 @@ def init_block(rngs):
820820
)
821821

822822
self.audio_norm_out = nnx.LayerNorm(
823-
audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
823+
audio_inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
824824
)
825825
self.audio_proj_out = nnx.Linear(
826826
audio_inner_dim,

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def test_load_transformer_weights(self):
104104
)
105105

106106
print("Validating Transformer Weights...")
107-
validate_flax_state_dict(eval_shapes, loaded_weights)
107+
from flax.traverse_util import flatten_dict
108+
validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
108109
print("Transformer Weights Validated Successfully!")
109110

110111
def test_load_vae_weights(self):
@@ -142,7 +143,7 @@ def test_load_vae_weights(self):
142143
continue
143144
filtered_eval_shapes[k] = v
144145

145-
validate_flax_state_dict(filtered_eval_shapes, loaded_weights)
146+
validate_flax_state_dict(filtered_eval_shapes, flatten_dict(loaded_weights))
146147
print("VAE Weights Validated Successfully!")
147148

148149
if __name__ == "__main__":

0 commit comments

Comments
 (0)