Skip to content

Commit 118e5d0

Browse files
committed
fix in transformer_ltx2.py
1 parent d868c1a commit 118e5d0

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ def init_block(rngs):
807807
# 6. Output layers
808808
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
809809
self.norm_out = nnx.LayerNorm(
810-
inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
810+
inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
811811
)
812812
self.proj_out = nnx.Linear(
813813
inner_dim,
@@ -820,7 +820,7 @@ def init_block(rngs):
820820
)
821821

822822
self.audio_norm_out = nnx.LayerNorm(
823-
audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
823+
audio_inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
824824
)
825825
self.audio_proj_out = nnx.Linear(
826826
audio_inner_dim,

0 commit comments

Comments
 (0)