@@ -1032,23 +1032,24 @@ def scan_fn(carry, block):
10321032 )
10331033
10341034 # 6. Output layers
1035- scale_shift_values = jnp .expand_dims (self .scale_shift_table , axis = (0 , 1 )) + jnp .expand_dims (embedded_timestep , axis = 2 )
1036- shift = scale_shift_values [:, :, 0 , :]
1037- scale = scale_shift_values [:, :, 1 , :]
1035+ with jax .named_scope ("Output Projection & Norm" ):
1036+ scale_shift_values = jnp .expand_dims (self .scale_shift_table , axis = (0 , 1 )) + jnp .expand_dims (embedded_timestep , axis = 2 )
1037+ shift = scale_shift_values [:, :, 0 , :]
1038+ scale = scale_shift_values [:, :, 1 , :]
10381039
1039- hidden_states = self .norm_out (hidden_states )
1040- hidden_states = hidden_states * (1 + scale ) + shift
1041- output = self .proj_out (hidden_states )
1040+ hidden_states = self .norm_out (hidden_states )
1041+ hidden_states = hidden_states * (1 + scale ) + shift
1042+ output = self .proj_out (hidden_states )
10421043
1043- audio_scale_shift_values = jnp .expand_dims (self .audio_scale_shift_table , axis = (0 , 1 )) + jnp .expand_dims (
1044- audio_embedded_timestep , axis = 2
1045- )
1046- audio_shift = audio_scale_shift_values [:, :, 0 , :]
1047- audio_scale = audio_scale_shift_values [:, :, 1 , :]
1044+ audio_scale_shift_values = jnp .expand_dims (self .audio_scale_shift_table , axis = (0 , 1 )) + jnp .expand_dims (
1045+ audio_embedded_timestep , axis = 2
1046+ )
1047+ audio_shift = audio_scale_shift_values [:, :, 0 , :]
1048+ audio_scale = audio_scale_shift_values [:, :, 1 , :]
10481049
1049- audio_hidden_states = self .audio_norm_out (audio_hidden_states )
1050- audio_hidden_states = audio_hidden_states * (1 + audio_scale ) + audio_shift
1051- audio_output = self .audio_proj_out (audio_hidden_states )
1050+ audio_hidden_states = self .audio_norm_out (audio_hidden_states )
1051+ audio_hidden_states = audio_hidden_states * (1 + audio_scale ) + audio_shift
1052+ audio_output = self .audio_proj_out (audio_hidden_states )
10521053
10531054 if not return_dict :
10541055 return (output , audio_output )
0 commit comments