Skip to content

Commit 863d2a3

Browse files
Merge pull request #3081 from AI-Hypercomputer:amandaliang
PiperOrigin-RevId: 865988464
2 parents 0f85477 + d5b9a6c commit 863d2a3

4 files changed

Lines changed: 49 additions & 8 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,14 @@ wo_tile_dlhs_mlp_dim: 1024
203203
wo_tile_drhs_batch_seq: 512
204204
wo_tile_drhs_embed_dim: 1024
205205
wo_tile_drhs_mlp_dim: 1024
206+
207+
wi_tile_fwd_buffer_count: 2
208+
wi_tile_dlhs_buffer_count: 2
209+
wi_tile_drhs_buffer_count: 2
210+
wo_tile_fwd_buffer_count: 2
211+
wo_tile_dlhs_buffer_count: 2
212+
wo_tile_drhs_buffer_count: 2
213+
206214
norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.
207215

208216
# how the expert axis is used to shard attention weights and activations

src/MaxText/configs/types.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -661,6 +661,13 @@ class MoEKernels(BaseModel):
661661
wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
662662
wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")
663663

664+
wi_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wi.")
665+
wi_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wi.")
666+
wi_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wi.")
667+
wo_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wo.")
668+
wo_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wo.")
669+
wo_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wo.")
670+
664671

665672
class DeepSeekMoE(BaseModel):
666673
"""Configuration specific to DeepSeek-style MoE layers."""

src/MaxText/layers/moe.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -877,7 +877,7 @@ def sparse_matmul(
877877
):
878878
"""Perform sparse matrix multiplication of inputs and Experts."""
879879

880-
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
880+
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count):
881881
pad_length = self.config.wi_tile_fwd_batch_seq
882882
hs_shape = inputs.shape
883883
# pad length is the 1st dimension of tiling size in gmm call
@@ -916,6 +916,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
916916
use_qwix_quantization=self.config.use_qwix_quantization,
917917
use_tokamax_backend=self.config.use_tokamax_gmm,
918918
weight_gather_axes=weight_gather_axes,
919+
input_buffer_count=input_buffer_count,
919920
)
920921
else:
921922
output = tokamax.ragged_dot(
@@ -1220,22 +1221,42 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12201221
self.config.wo_tile_drhs_embed_dim,
12211222
self.config.wo_tile_drhs_mlp_dim,
12221223
)
1223-
layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
1224+
wi_input_buffer_count = (
1225+
self.config.wi_tile_fwd_buffer_count,
1226+
self.config.wi_tile_dlhs_buffer_count,
1227+
self.config.wi_tile_drhs_buffer_count,
1228+
)
1229+
wo_input_buffer_count = (
1230+
self.config.wo_tile_fwd_buffer_count,
1231+
self.config.wo_tile_dlhs_buffer_count,
1232+
self.config.wo_tile_drhs_buffer_count,
1233+
)
1234+
layer_w0 = gmm_fn(
1235+
x, w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, input_buffer_count=wi_input_buffer_count
1236+
)
12241237
if self.get_tensor_transpose_parallelism_size() > 1:
12251238
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
12261239
if self.config.mlp_bias:
12271240
layer_w0 = layer_w0 + w0_bias
12281241
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
12291242

1230-
layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
1243+
layer_w1 = gmm_fn(
1244+
x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes, input_buffer_count=wi_input_buffer_count
1245+
)
12311246
if self.get_tensor_transpose_parallelism_size() > 1:
12321247
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
12331248
if self.config.mlp_bias:
12341249
layer_w1 = layer_w1 + w1_bias
12351250
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
12361251
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)
12371252

1238-
intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes)
1253+
intermediate_output = gmm_fn(
1254+
intermediate_layer,
1255+
wo,
1256+
tiling=wo_tile_size,
1257+
weight_gather_axes=wo_gather_axes,
1258+
input_buffer_count=wo_input_buffer_count,
1259+
)
12391260
if self.get_tensor_parallelism_size() > 1:
12401261
intermediate_output = jax.lax.psum_scatter(
12411262
intermediate_output, self._tensor_parallelism_name, scatter_dimension=1, tiled=True

src/maxtext/kernels/megablox/ops.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ def gmm(
4242
use_qwix_quantization: bool = False,
4343
use_tokamax_backend: bool = False,
4444
weight_gather_axes: List[Tuple[str, int]] | None = None,
45+
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
4546
):
4647
"""Grouped matrix multiplication operation."""
4748
quantization_rule = None
@@ -61,14 +62,15 @@ def gmm(
6162
)
6263

6364
gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
64-
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
65+
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 8, 9, 10, 11, 12))
6566
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
6667
return gmm_fwd_bwd(
6768
lhs,
6869
rhs,
6970
group_sizes,
7071
preferred_element_type,
7172
tiling,
73+
input_buffer_count,
7274
group_offset,
7375
existing_out,
7476
transpose_rhs,
@@ -85,6 +87,7 @@ def _gmm_fwd(
8587
group_sizes: jnp.ndarray,
8688
preferred_element_type: jnp.dtype = jnp.float32,
8789
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
90+
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
8891
group_offset: jnp.ndarray | None = None,
8992
existing_out: jnp.ndarray | None = None,
9093
transpose_rhs: bool = False,
@@ -125,9 +128,7 @@ def _gmm_fwd(
125128
# QAG is only supported for following conditions
126129
if use_tokamax_backend:
127130
if quantization_rule and quantization_rule.bwd_qtype:
128-
if quantization_rule.weight_calibration_method.startswith(
129-
"fixed"
130-
) and isinstance(rhs, qpl.QArray):
131+
if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
131132
if weight_gather_axes:
132133
for axis_name, axis_idx in weight_gather_axes:
133134
rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
@@ -142,6 +143,7 @@ def _gmm_fwd(
142143
group_offset=group_offset,
143144
transpose_rhs=transpose_rhs,
144145
interpret=interpret,
146+
input_buffer_count=input_buffer_count[0],
145147
)
146148
else:
147149
out = backend.gmm(
@@ -163,6 +165,7 @@ def _gmm_bwd(
163165
rhs_dtype: jax.typing.DTypeLike,
164166
preferred_element_type: jnp.dtype,
165167
tiling: tuple[int, int, int, int, int, int, int, int, int],
168+
input_buffer_count: tuple[int, int, int],
166169
transpose_rhs: bool,
167170
interpret: bool,
168171
quantization_rule: qwix.QtRule | None,
@@ -229,6 +232,7 @@ def _gmm_bwd(
229232
group_offset=group_offset,
230233
transpose_rhs=not transpose_rhs,
231234
interpret=interpret,
235+
input_buffer_count=input_buffer_count[1],
232236
)
233237
drhs = tokamax_backend.tgmm(
234238
lhs=lhs.swapaxes(0, 1),
@@ -240,6 +244,7 @@ def _gmm_bwd(
240244
group_offset=group_offset,
241245
num_actual_groups=num_actual_groups,
242246
interpret=interpret,
247+
input_buffer_count=input_buffer_count[2],
243248
)
244249
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
245250
# Scatter back in reverse order of gather

0 commit comments

Comments (0)