@@ -158,14 +158,6 @@ def __init__(
158158 rope_type = rope_type ,
159159 )
160160
161- # Scale Shift Tables
162- self .scale_shift_table = nnx .Param (jax .random .normal (rngs .params (), (6 , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim ))
163- self .audio_scale_shift_table = nnx .Param (
164- jax .random .normal (rngs .params (), (6 , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim )
165- )
166- self .video_a2v_cross_attn_scale_shift_table = nnx .Param (jax .random .normal (rngs .params (), (5 , self .dim ), dtype = weights_dtype ))
167- self .audio_a2v_cross_attn_scale_shift_table = nnx .Param (jax .random .normal (rngs .params (), (5 , audio_dim ), dtype = weights_dtype ))
168-
169161 # 2. Prompt Cross-Attention
170162 self .norm2 = nnx .RMSNorm (
171163 self .dim ,
@@ -815,7 +807,7 @@ def init_block(rngs):
815807 # 6. Output layers
816808 self .gradient_checkpoint = GradientCheckpointType .from_str (remat_policy )
817809 self .norm_out = nnx .LayerNorm (
818- inner_dim , epsilon = 1e-6 , use_scale = False , use_bias = False , rngs = rngs , dtype = jnp .float32 , param_dtype = jnp .float32
810+ inner_dim , epsilon = 1e-6 , use_scale = False , rngs = rngs , dtype = jnp .float32 , param_dtype = jnp .float32
819811 )
820812 self .proj_out = nnx .Linear (
821813 inner_dim ,
@@ -828,7 +820,7 @@ def init_block(rngs):
828820 )
829821
830822 self .audio_norm_out = nnx .LayerNorm (
831- audio_inner_dim , epsilon = 1e-6 , use_scale = False , use_bias = False , rngs = rngs , dtype = jnp .float32 , param_dtype = jnp .float32
823+ audio_inner_dim , epsilon = 1e-6 , use_scale = False , rngs = rngs , dtype = jnp .float32 , param_dtype = jnp .float32
832824 )
833825 self .audio_proj_out = nnx .Linear (
834826 audio_inner_dim ,
0 commit comments