 from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
 from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
 import numpy as np
+import qwix.pallas as qpl
 import tokamax

 set_xla_metadata = xla_metadata.set_xla_metadata
@@ -792,7 +793,7 @@ def sparse_matmul(
   ):
     """Perform sparse matrix multiplication of inputs and Experts."""

-    def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
+    def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
       pad_length = self.config.wi_tile_fwd_batch_seq
       hs_shape = inputs.shape
       # pad length is the 1st dimension of tiling size in gmm call
@@ -830,7 +831,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             rhs_quantize_dtype=rhs_quantize_dtype,
             use_qwix_quantization=self.config.use_qwix_quantization,
             use_tokamax_backend=self.config.use_tokamax_gmm,
-            is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp,
+            weight_gather_axes=weight_gather_axes,
         )
       else:
         output = tokamax.ragged_dot(
@@ -853,7 +854,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             rhs_quantize_dtype=rhs_quantize_dtype,
             use_qwix_quantization=self.config.use_qwix_quantization,
             use_tokamax_backend=self.config.use_tokamax_gmm,
-            is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp,
+            weight_gather_axes=weight_gather_axes,
         )
       else:
         rhs_inputs = kernel
@@ -935,12 +936,15 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):

     # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use
     # mlp_no_fsdp axis
+    weight_gather = False
     if self.config.fsdp_shard_on_exp:
-      if self.config.quantization:
-        # special sharding when quantization is enabled with fsdp_shard_on_exp
+      quantization_rule = qpl.get_current_rule("gmm")
+      if quantization_rule and quantization_rule.weight_calibration_method.startswith("fixed"):
+        # special sharding when using static scaling for weights in quantization with fsdp_shard_on_exp
         w0_pspec = nn.logical_to_mesh_axes(self.wi_kernel_axes)
         w1_pspec = nn.logical_to_mesh_axes(self.wi_kernel_axes)
         wo_pspec = nn.logical_to_mesh_axes(self.wo_kernel_axes)
+        weight_gather = True
       else:
         # special sharding for dsv3 to remove overhead between gmm/AG
         w0_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
@@ -1069,7 +1073,25 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r

       if self.config.mlp_bias:
         w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
-
+      def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
+        if pspec_dim_axes is None: return []
+        axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
+        active = []
+        for ax in axes:
+          if ax and self.mesh.shape.get(ax, 1) > 1:
+            active.append((ax, tensor_dim_index))
+        return active
+      wi_gather_axes = []
+      wo_gather_axes = []
+
+      if weight_gather:
+        # wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
+        wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
+        wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))
+
+        # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
+        wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
+        wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
       gmm_fn = functools.partial(
           gmm,
           group_sizes=group_sizes,
@@ -1097,22 +1119,22 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
           self.config.wo_tile_drhs_embed_dim,
           self.config.wo_tile_drhs_mlp_dim,
       )
-      layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
+      layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
       if self.get_tensor_transpose_parallelism_size() > 1:
         layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
       if self.config.mlp_bias:
         layer_w0 = layer_w0 + w0_bias
       layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")

-      layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
+      layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
       if self.get_tensor_transpose_parallelism_size() > 1:
         layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
       if self.config.mlp_bias:
         layer_w1 = layer_w1 + w1_bias
       layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
       intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

-      intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
+      intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes)
       if self.get_tensor_parallelism_size() > 1:
         intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True)
       if self.config.mlp_bias:
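For reviewers: a standalone sketch of what the new `get_active_sharding_axes` helper and the `weight_gather_axes` lists compute, using a hypothetical mesh and partition spec (the real code reads these from `self.mesh` and `nn.logical_to_mesh_axes`):

```python
# Illustration only -- mirrors the helper added in this diff, but takes the mesh
# axis sizes as a plain dict instead of closing over self.mesh.
def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index, mesh_shape):
  """Return (mesh_axis, tensor_dim) pairs for mesh axes that actually shard this tensor dim."""
  if pspec_dim_axes is None:
    return []
  axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
  active = []
  for ax in axes:
    # A mesh axis of size 1 does not shard anything, so no gather is needed for it.
    if ax and mesh_shape.get(ax, 1) > 1:
      active.append((ax, tensor_dim_index))
  return active

# Hypothetical mesh sizes and a wi pspec of the form (experts, embed, mlp).
mesh_shape = {"expert": 4, "fsdp": 8, "tensor": 1}
w0_pspec = ("expert", "fsdp", "tensor")

wi_gather_axes = []
# wi is [Experts, In, Hidden]: gather over whatever shards dims 0 and 2.
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0, mesh_shape))
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2, mesh_shape))
print(wi_gather_axes)  # [('expert', 0)] -- "tensor" has size 1, so dim 2 needs no gather
```

Net effect of the diff: `gmm` no longer infers gathering behavior from the boolean `is_fsdp_shard_on_exp` config flag; the caller passes the exact `(mesh_axis, tensor_dim)` pairs to gather over, and those lists are non-empty only when `fsdp_shard_on_exp` is set and the current qwix "gmm" rule uses a fixed weight calibration method.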