Commit 796eaeb

Weight gathering in multiple axes
Replaces the hardcoded `is_fsdp_shard_on_exp` flag in the Megablox GMM kernel with a `weight_gather_axes` parameter. This allows specifying multiple mesh axes and tensor dimensions along which quantized weights are gathered before the computation, and along which gradients are scattered back in the backward pass. The MoE layer now computes the required gather axes from each weight's partition spec and the quantization settings. This is needed for sharding weights along multiple axes, such as `ici_fsdp_transpose_parallelism=2` and `ici_fsdp_parallelism=256` in a 256-chip configuration for DeepSeek v3 FP8 quantized training.

PiperOrigin-RevId: 839652914
1 parent 60028c4 commit 796eaeb
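
For reference, a minimal sketch of the new parameter's format (the axis names and dimensions below are illustrative, not taken from a real config): each entry pairs a mesh axis name with the tensor dimension of the weight that is sharded along it.

# Hypothetical example values: gather the quantized expert weights over two
# mesh axes before the grouped matmul, and reduce-scatter the weight gradient
# back over the same axes in the backward pass.
weight_gather_axes = [
    ("fsdp", 0),            # experts dimension sharded on the fsdp axis
    ("fsdp_transpose", 2),  # mlp dimension sharded on the fsdp_transpose axis
]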

2 files changed: 44 additions & 19 deletions


src/MaxText/kernels/megablox/ops.py

Lines changed: 13 additions & 10 deletions
@@ -18,7 +18,7 @@
 
 import functools
 import dataclasses
-from typing import Literal
+from typing import Literal, List, Tuple
 import jax
 import jax.numpy as jnp
 from MaxText.kernels.megablox import backend
@@ -41,7 +41,7 @@ def gmm(
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     use_qwix_quantization: bool = False,
     use_tokamax_backend: bool = False,
-    is_fsdp_shard_on_exp: bool = False,
+    weight_gather_axes: List[Tuple[str, int]] | None = None,
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
@@ -75,7 +75,7 @@ def gmm(
       interpret,
       quantization_rule,
       use_tokamax_backend,
-      is_fsdp_shard_on_exp,
+      weight_gather_axes,
   )
 
 
@@ -91,7 +91,7 @@ def _gmm_fwd(
     interpret: bool = False,
     quantization_rule: qwix.QtRule | None = None,
    use_tokamax_backend: bool = False,
-    is_fsdp_shard_on_exp: bool = False,
+    weight_gather_axes: List[Tuple[str, int]] | None = None,
 ) -> tuple[
     jnp.ndarray,
     tuple[
@@ -128,10 +128,11 @@ def _gmm_fwd(
     if (
         quantization_rule.weight_calibration_method.startswith("fixed")
         and isinstance(rhs, qpl.QArray)
-        and is_fsdp_shard_on_exp
     ):
-      rhs_qvalue = jax.lax.all_gather(rhs.qvalue, "fsdp", axis=0, tiled=True)
-      rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
+      if weight_gather_axes:
+        for axis_name, axis_idx in weight_gather_axes:
+          rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
+          rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
     out = tokamax_backend.gmm(
         lhs=lhs,
         rhs=rhs,
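
To see the forward-pass gather in isolation, here is a self-contained sketch, not MaxText code: the device count, shapes, axis names, and the gather_qvalue helper are assumptions chosen to mirror the loop above. Repeated tiled all_gather calls rebuild a weight that is sharded along two mesh axes.

import functools

import jax
import jax.numpy as jnp
import numpy as np
# Depending on JAX version, shard_map may also be available as jax.shard_map.
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumes at least 4 devices, e.g. run on CPU with
# XLA_FLAGS=--xla_force_host_platform_device_count=4
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=("fsdp", "fsdp_transpose"))

# Stand-in for rhs.qvalue: [experts=4, emb=8, mlp=8], sharded on dims 0 and 2.
qvalue = jnp.ones((4, 8, 8), dtype=jnp.int8)
weight_gather_axes = [("fsdp", 0), ("fsdp_transpose", 2)]

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=P("fsdp", None, "fsdp_transpose"),
    out_specs=P(),  # fully replicated after both gathers
)
def gather_qvalue(shard):  # per-device block is [2, 8, 4]
  for axis_name, axis_idx in weight_gather_axes:
    # tiled=True concatenates the shards along axis_idx instead of stacking
    # them on a new leading axis.
    shard = jax.lax.all_gather(shard, axis_name, axis=axis_idx, tiled=True)
  return shard

assert gather_qvalue(qvalue).shape == (4, 8, 8)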
@@ -167,7 +168,7 @@ def _gmm_bwd(
     interpret: bool,
     quantization_rule: qwix.QtRule | None,
     use_tokamax_backend: bool,
-    is_fsdp_shard_on_exp: bool,
+    weight_gather_axes: List[Tuple[str, int]] | None,
     residual: tuple[
         jnp.ndarray | qpl.QArray,
         jnp.ndarray | qpl.QArray,
@@ -241,8 +242,10 @@ def _gmm_bwd(
         num_actual_groups=num_actual_groups,
         interpret=interpret,
     )
-    if quantization_rule and quantization_rule.bwd_qtype and is_fsdp_shard_on_exp:
-      drhs = jax.lax.psum_scatter(drhs, "fsdp", scatter_dimension=0, tiled=True)
+    if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
+      # Scatter back in reverse order of gather
+      for axis_name, axis_idx in reversed(weight_gather_axes):
+        drhs = jax.lax.psum_scatter(drhs, axis_name, scatter_dimension=axis_idx, tiled=True)
   else:
     dlhs = backend.gmm(
         dlhs_dout,
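
The backward hunk mirrors the forward loop: the gradient of a tiled all_gather along a mesh axis is a tiled psum_scatter along the same axis, so undoing the gathers in reverse order restores the weight gradient's original sharded layout. A minimal sketch under the same illustrative mesh and shapes as the example above (the scatter_weight_grad helper is hypothetical and would run inside a shard_map context where the axis names are bound):

import jax

def scatter_weight_grad(drhs, weight_gather_axes):
  # drhs: per-device gradient of the gathered weight, e.g. [4, 8, 8].
  for axis_name, axis_idx in reversed(weight_gather_axes):
    # After ("fsdp_transpose", 2): [4, 8, 4]; after ("fsdp", 0): [2, 8, 4],
    # matching the original sharded qvalue block.
    drhs = jax.lax.psum_scatter(drhs, axis_name, scatter_dimension=axis_idx, tiled=True)
  return drhs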

src/MaxText/layers/moe.py

Lines changed: 31 additions & 9 deletions
@@ -35,6 +35,7 @@
 from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
 from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
 import numpy as np
+import qwix.pallas as qpl
 import tokamax
 
 set_xla_metadata = xla_metadata.set_xla_metadata
@@ -792,7 +793,7 @@ def sparse_matmul(
   ):
     """Perform sparse matrix multiplication of inputs and Experts."""
 
-    def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
+    def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
       pad_length = self.config.wi_tile_fwd_batch_seq
       hs_shape = inputs.shape
       # pad length is the 1st dimension of tiling size in gmm call
@@ -830,7 +831,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             rhs_quantize_dtype=rhs_quantize_dtype,
             use_qwix_quantization=self.config.use_qwix_quantization,
             use_tokamax_backend=self.config.use_tokamax_gmm,
-            is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp,
+            weight_gather_axes=weight_gather_axes,
         )
       else:
         output = tokamax.ragged_dot(
@@ -853,7 +854,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
             rhs_quantize_dtype=rhs_quantize_dtype,
             use_qwix_quantization=self.config.use_qwix_quantization,
             use_tokamax_backend=self.config.use_tokamax_gmm,
-            is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp,
+            weight_gather_axes=weight_gather_axes,
         )
       else:
         rhs_inputs = kernel
@@ -935,12 +936,15 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
 
     # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use
    # mlp_no_fsdp axis
+    weight_gather = False
     if self.config.fsdp_shard_on_exp:
-      if self.config.quantization:
-        # special sharding when quantization is enabled with fsdp_shard_on_exp
+      quantization_rule = qpl.get_current_rule("gmm")
+      if quantization_rule and quantization_rule.weight_calibration_method.startswith("fixed"):
+        # special sharding when using static scaling for weights in quantization with fsdp_shard_on_exp
         w0_pspec = nn.logical_to_mesh_axes(self.wi_kernel_axes)
         w1_pspec = nn.logical_to_mesh_axes(self.wi_kernel_axes)
         wo_pspec = nn.logical_to_mesh_axes(self.wo_kernel_axes)
+        weight_gather = True
       else:
         # special sharding for dsv3 to remove overhead between gmm/AG
         w0_pspec = nn.logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
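
Put differently, with fsdp_shard_on_exp the layer now keys off the active qwix rule instead of the generic quantization flag: only a "fixed*" (static-scale) weight calibration keeps the fsdp-sharded kernel axes and lets the kernel gather the quantized values itself. A condensed sketch of that decision, reusing the names from the hunk above (not a verbatim excerpt):

# Condensed sketch of the branch above.
quantization_rule = qpl.get_current_rule("gmm")
static_weight_scales = bool(
    quantization_rule
    and quantization_rule.weight_calibration_method.startswith("fixed")
)
# Gather inside the gmm kernel only when the weights stay sharded on the
# fsdp / fsdp_transpose axes and use static (fixed) scales.
weight_gather = self.config.fsdp_shard_on_exp and static_weight_scales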
@@ -1069,7 +1073,25 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
 
     if self.config.mlp_bias:
       w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
-
+    def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
+      if pspec_dim_axes is None: return []
+      axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
+      active = []
+      for ax in axes:
+        if ax and self.mesh.shape.get(ax, 1) > 1:
+          active.append((ax, tensor_dim_index))
+      return active
+    wi_gather_axes = []
+    wo_gather_axes = []
+
+    if weight_gather:
+      # wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
+      wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
+      wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))
+
+      # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
+      wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
+      wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
     gmm_fn = functools.partial(
         gmm,
         group_sizes=group_sizes,
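
A self-contained sketch of what the new get_active_sharding_axes helper computes, using a hypothetical mesh shape and partition spec (real specs come from nn.logical_to_mesh_axes and the configured sharding rules; the axis names and sizes here are assumptions):

from jax.sharding import PartitionSpec as P

mesh_shape = {"fsdp": 128, "fsdp_transpose": 2, "tensor": 1}  # hypothetical

def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
  # Same logic as the helper above, with self.mesh.shape replaced by the dict.
  if pspec_dim_axes is None:
    return []
  axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
  return [(ax, tensor_dim_index) for ax in axes if ax and mesh_shape.get(ax, 1) > 1]

# wi kernel [experts, emb, mlp]; a single dim may name a tuple of mesh axes.
w0_pspec = P("fsdp", None, ("fsdp_transpose", "tensor"))

wi_gather_axes = []
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))

# "tensor" is skipped because its mesh size is 1.
print(wi_gather_axes)  # [('fsdp', 0), ('fsdp_transpose', 2)]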
@@ -1097,22 +1119,22 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
         self.config.wo_tile_drhs_embed_dim,
         self.config.wo_tile_drhs_mlp_dim,
     )
-    layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
+    layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
     if self.get_tensor_transpose_parallelism_size() > 1:
       layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w0 = layer_w0 + w0_bias
     layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
 
-    layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
+    layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
     if self.get_tensor_transpose_parallelism_size() > 1:
       layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w1 = layer_w1 + w1_bias
     layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
     intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)
 
-    intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
+    intermediate_output = gmm_fn(intermediate_layer, wo, tiling=wo_tile_size, weight_gather_axes=wo_gather_axes)
     if self.get_tensor_parallelism_size() > 1:
       intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True)
     if self.config.mlp_bias:
