
Commit 758d8a4

transformer file changed
1 parent 4534be4 commit 758d8a4

2 files changed: 6 additions & 2 deletions


src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 4 additions & 0 deletions
@@ -115,6 +115,10 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
   # Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
   flax_key_str = [str(k) for k in flax_key]

+  # DEBUG: Check specific keys
+  if "norm_k" in flax_key_str or "audio_caption_projection" in flax_key_str:
+    print(f"DEBUG: get_key_and_value mapping: {pt_tuple_key} -> {flax_key_str}")
+
   if flax_key_str[-1] in ["kernel", "weight"]:
     # Try replacing with scale and check if it exists in random_flax_state_dict
     temp_key_str = flax_key_str[:-1] + ["scale"]
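For context, the added DEBUG block traces how a PyTorch checkpoint key tuple is translated into a Flax key path for the two layers under investigation. A minimal runnable sketch of the kind of mapping it prints, using an illustrative key tuple and a simplified rename rule (the real logic lives in rename_key and the rest of get_key_and_value):

# Illustrative key tuple; not taken from an actual LTX2 checkpoint.
pt_tuple_key = ("transformer_blocks", "0", "attn1", "norm_k", "weight")

# Flax stores norm weights under "scale" (and Linear weights under "kernel"),
# so a PyTorch "weight" leaf on a norm layer is renamed accordingly.
flax_key = pt_tuple_key[:-1] + ("scale",)

flax_key_str = [str(k) for k in flax_key]
if "norm_k" in flax_key_str or "audio_caption_projection" in flax_key_str:
  print(f"DEBUG: get_key_and_value mapping: {pt_tuple_key} -> {flax_key_str}")
# DEBUG: get_key_and_value mapping: ('transformer_blocks', '0', 'attn1', 'norm_k', 'weight') -> ['transformer_blocks', '0', 'attn1', 'norm_k', 'scale']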

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 2 deletions
@@ -807,7 +807,7 @@ def init_block(rngs):
     # 6. Output layers
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
     self.norm_out = nnx.LayerNorm(
-        inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
+        inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
     )
     self.proj_out = nnx.Linear(
         inner_dim,
@@ -820,7 +820,7 @@
     )

     self.audio_norm_out = nnx.LayerNorm(
-        audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
+        audio_inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
     )
     self.audio_proj_out = nnx.Linear(
         audio_inner_dim,
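Why the one-argument change matters: flax.nnx.LayerNorm creates a learnable bias parameter by default (use_bias=True), so the previous construction presumably left a bias leaf in the Flax state with no counterpart in the source checkpoint. With both use_scale=False and use_bias=False the layer is pure normalization with no parameters at all. A minimal sketch, assuming an illustrative inner_dim:

import jax.numpy as jnp
from flax import nnx

inner_dim = 64  # illustrative; the real value comes from the model config

# With use_scale=False and use_bias=False this LayerNorm has no learnable
# `scale` or `bias` leaves, matching a checkpoint that stores none for it.
norm_out = nnx.LayerNorm(
    inner_dim, epsilon=1e-6, use_scale=False, use_bias=False,
    rngs=nnx.Rngs(0), dtype=jnp.float32, param_dtype=jnp.float32,
)

x = jnp.linspace(-1.0, 1.0, inner_dim).reshape(1, inner_dim)
y = norm_out(x)  # normalized to zero mean, unit variance over the feature axis
print(y.shape)   # (1, 64)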
