@@ -887,23 +887,34 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
887887
888888 wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [0 ], 0 ))
889889 wo_gather_axes .extend (get_active_sharding_axes (wo_pspec [1 ], 1 ))
890-
891- layer_w0 = gmm_fn (
892- x ,
893- w0 ,
894- tiling = wi_tile_size ,
895- weight_gather_axes = wi_gather_axes ,
896- input_buffer_count = wi_input_buffer_count ,
897- combine_scopes = wi_combine_scopes ,
898- )
899- layer_w1 = gmm_fn (
900- x ,
901- w1 ,
902- tiling = wi_tile_size ,
903- weight_gather_axes = wi_gather_axes ,
904- input_buffer_count = wi_input_buffer_count ,
905- combine_scopes = wi_combine_scopes ,
906- )
890+ if config .merge_gating_gmm :
891+ w01 = jnp .concatenate ([w0 , w1 ], axis = - 1 )
892+ layer_w01 = gmm_fn (
893+ x ,
894+ w01 ,
895+ tiling = wi_tile_size ,
896+ weight_gather_axes = wi_gather_axes ,
897+ input_buffer_count = wi_input_buffer_count ,
898+ combine_scopes = wi_combine_scopes ,
899+ )
900+ layer_w0 , layer_w1 = jnp .split (layer_w01 , 2 , axis = - 1 )
901+ else :
902+ layer_w0 = gmm_fn (
903+ x ,
904+ w0 ,
905+ tiling = wi_tile_size ,
906+ weight_gather_axes = wi_gather_axes ,
907+ input_buffer_count = wi_input_buffer_count ,
908+ combine_scopes = wi_combine_scopes ,
909+ )
910+ layer_w1 = gmm_fn (
911+ x ,
912+ w1 ,
913+ tiling = wi_tile_size ,
914+ weight_gather_axes = wi_gather_axes ,
915+ input_buffer_count = wi_input_buffer_count ,
916+ combine_scopes = wi_combine_scopes ,
917+ )
907918 layer_w0 = jax .ad_checkpoint .checkpoint_name (layer_w0 , "mlpwi_0" )
908919 layer_w1 = jax .ad_checkpoint .checkpoint_name (layer_w1 , "mlpwi_1" )
909920 intermediate_layer = jax .nn .silu (layer_w0 ) * layer_w1
0 commit comments