@@ -208,19 +208,21 @@ def __init__(
208208 weights_dtype = weights_dtype
209209 )
210210
211- scale_rng , init_rng = nnx .split_rngs (rngs , "params" , "initialization" )
211+
212+ key = rngs .params ()
213+ k1 , k2 , k3 , k4 = jax .random .split (key , 4 )
212214
213215 self .scale_shift_table = nnx .Param (
214- jax .random .normal (init_rng () , (6 , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim )
216+ jax .random .normal (k1 , (6 , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim )
215217 )
216218 self .audio_scale_shift_table = nnx .Param (
217- jax .random .normal (init_rng () , (6 , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim )
219+ jax .random .normal (k2 , (6 , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim )
218220 )
219221 self .video_a2v_cross_attn_scale_shift_table = nnx .Param (
220- jax .random .normal (init_rng () , (5 , self .dim ), dtype = weights_dtype )
222+ jax .random .normal (k3 , (5 , self .dim ), dtype = weights_dtype )
221223 )
222224 self .audio_a2v_cross_attn_scale_shift_table = nnx .Param (
223- jax .random .normal (init_rng () , (5 , audio_dim ), dtype = weights_dtype )
225+ jax .random .normal (k4 , (5 , audio_dim ), dtype = weights_dtype )
224226 )
225227
226228 def __call__ (
0 commit comments