
Commit 44039d8

Merge pull request #3199 from AI-Hypercomputer:amandaliang
PiperOrigin-RevId: 875485665
2 parents: 8a0a215 + 3c4d81d

3 files changed: 32 additions & 17 deletions

File tree:

src/maxtext/configs/base.yml
src/maxtext/configs/types.py
src/maxtext/models/deepseek_batchsplit.py

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
@@ -214,6 +214,8 @@ wo_tile_drhs_buffer_count: 2
 wi_combine_scopes: False
 wo_combine_scopes: False
 
+merge_gating_gmm: False
+
 norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.
 
 # how the expert axis is used to shard attention weights and activations
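For context: the new key defaults to False, so the merged gating GMM path is strictly opt-in and existing configs keep the two-kernel behavior. A minimal sketch of what the parsed value looks like, assuming PyYAML (not part of this change):

import yaml

# The relevant slice of base.yml after this commit; YAML resolves the
# capitalized "False" literal to a Python boolean.
snippet = """
wi_combine_scopes: False
wo_combine_scopes: False

merge_gating_gmm: False
"""

cfg = yaml.safe_load(snippet)
assert cfg["merge_gating_gmm"] is False  # merged gating GMM stays off by default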

src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
@@ -685,6 +685,8 @@ class MoEKernels(BaseModel):
   wi_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wi.")
   wo_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wo.")
 
+  merge_gating_gmm: bool = Field(False, description="whether to merge the two gating gmm kernels into one.")
+
 
 class DeepSeekMoE(BaseModel):
   """Configuration specific to DeepSeek-style MoE layers."""

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 28 additions & 17 deletions
@@ -887,23 +887,34 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
     wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
     wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
-
-    layer_w0 = gmm_fn(
-        x,
-        w0,
-        tiling=wi_tile_size,
-        weight_gather_axes=wi_gather_axes,
-        input_buffer_count=wi_input_buffer_count,
-        combine_scopes=wi_combine_scopes,
-    )
-    layer_w1 = gmm_fn(
-        x,
-        w1,
-        tiling=wi_tile_size,
-        weight_gather_axes=wi_gather_axes,
-        input_buffer_count=wi_input_buffer_count,
-        combine_scopes=wi_combine_scopes,
-    )
+    if config.merge_gating_gmm:
+      w01 = jnp.concatenate([w0, w1], axis=-1)
+      layer_w01 = gmm_fn(
+          x,
+          w01,
+          tiling=wi_tile_size,
+          weight_gather_axes=wi_gather_axes,
+          input_buffer_count=wi_input_buffer_count,
+          combine_scopes=wi_combine_scopes,
+      )
+      layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
+    else:
+      layer_w0 = gmm_fn(
+          x,
+          w0,
+          tiling=wi_tile_size,
+          weight_gather_axes=wi_gather_axes,
+          input_buffer_count=wi_input_buffer_count,
+          combine_scopes=wi_combine_scopes,
+      )
+      layer_w1 = gmm_fn(
+          x,
+          w1,
+          tiling=wi_tile_size,
+          weight_gather_axes=wi_gather_axes,
+          input_buffer_count=wi_input_buffer_count,
+          combine_scopes=wi_combine_scopes,
+      )
     layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
     layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
     intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
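The merged branch works because matrix multiplication distributes over concatenation of the weight's output axis: x @ concat([w0, w1], axis=-1) == concat([x @ w0, x @ w1], axis=-1), so one fused GMM followed by a split reproduces the two-kernel outputs. A minimal JAX sketch, with plain matmuls standing in for gmm_fn and hypothetical toy shapes:

import jax
import jax.numpy as jnp

k0, k1, kx = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (8, 16))    # hypothetical shapes: tokens x model_dim
w0 = jax.random.normal(k0, (16, 32))  # first gating projection
w1 = jax.random.normal(k1, (16, 32))  # second gating projection

# Two-kernel path: one matmul per gating projection.
y0, y1 = x @ w0, x @ w1

# Merged path, mirroring the merge_gating_gmm branch: concatenate the
# weights on the output axis, run one matmul, split the result in half.
y01 = x @ jnp.concatenate([w0, w1], axis=-1)
m0, m1 = jnp.split(y01, 2, axis=-1)

assert jnp.allclose(y0, m0) and jnp.allclose(y1, m1)

# The downstream SwiGLU-style activation is therefore unchanged.
assert jnp.allclose(jax.nn.silu(y0) * y1, jax.nn.silu(m0) * m1)

The likely motivation is launching one larger grouped matmul instead of two smaller ones (fewer kernel launches, better tile utilization); the cost is the jnp.concatenate on the weights, which materializes a fused buffer unless the compiler fuses it away.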
