Commit bf63749

Support FP8 in batch split config
1 parent: 55b57ff

5 files changed: 262 additions & 68 deletions

File tree

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 219 additions & 58 deletions
@@ -15,18 +15,95 @@
 
 """Alternative DeepSeek model definition with batch-split schedule."""
 
+import dataclasses
 import functools
 import math
-from typing import Sequence
+from typing import Any, Sequence
 
 import jax
 import jax.numpy as jnp
-from maxtext.kernels import megablox
-from maxtext.kernels import sort_activations
+import qwix.pallas as qpl
+import tokamax
+from flax import linen as nn
+
+from maxtext.kernels import megablox, sort_activations
 from MaxText.layers import attention_op
+from MaxText.layers import moe as moe_lib
 from MaxText.layers import quantizations
 
 
+@functools.partial(
+    jax.custom_vjp,
+    nondiff_argnums=(
+        1,
+        2,
+        3,
+    ),
+)
+def quantized_psum_scatter(x: jax.Array, axis_name: str, scatter_dimension: int, tiled: bool) -> jax.Array:
+  """Forward: standard BF16 reduce-scatter.
+
+  Backward: quantized FP8 all-gather (DeepSeek optimization).
+
+  Args:
+    x: The input tensor.
+    axis_name: The axis name for the psum_scatter/all_gather operation.
+    scatter_dimension: The dimension along which to scatter.
+    tiled: Whether the scatter/gather is tiled.
+
+  Returns:
+    The result of the reduce-scatter operation.
+  """
+  return _q_psum_scatter_fwd(x, axis_name, scatter_dimension, tiled)[0]
+
+
+def _q_psum_scatter_fwd(x: jax.Array, axis_name: str, scatter_dimension: int, tiled: bool) -> tuple[jax.Array, None]:
+  out = jax.lax.psum_scatter(x, axis_name=axis_name, scatter_dimension=scatter_dimension, tiled=tiled)
+  return out, None
+
+
+def _q_psum_scatter_bwd(
+    axis_name: str,
+    scatter_dimension: int,
+    tiled: bool,
+    res: Any,
+    grads: jax.Array,
+) -> tuple[jax.Array]:  # pylint: disable=g-one-element-tuple
+  """Backward pass for quantized_psum_scatter.
+
+  Performs a quantized all-gather of the gradients.
+
+  Args:
+    axis_name: The axis name for the all_gather operation.
+    scatter_dimension: The dimension along which the scatter occurred in the
+      forward pass.
+    tiled: Whether the gather is tiled.
+    res: The residuals from the forward pass (_q_psum_scatter_fwd); always
+      None, since the non-differentiable arguments are passed directly.
+    grads: The gradients from the next layer, which are in BF16.
+
+  Returns:
+    The dequantized and all-gathered gradients.
+  """
+  del res
+  # --- BACKWARD PASS (Dispatch) ---
+  # 'grads' is the BF16 gradient arriving from the next layer.
+  # We need to broadcast it back to all devices (All-Gather).
+
+  grads_q = qpl.quantize(
+      grads,
+      jnp.float8_e5m2,
+      channelwise_axes=[0],
+  )
+
+  gathered_qvals = jax.lax.all_gather(grads_q.qvalue, axis_name=axis_name, tiled=tiled, axis=scatter_dimension)
+
+  return (qpl.dequantize(dataclasses.replace(grads_q, qvalue=gathered_qvals)),)
+
+
+quantized_psum_scatter.defvjp(_q_psum_scatter_fwd, _q_psum_scatter_bwd)
+
+
 def fetch_weights(params, dtype):
   """Fetches weights from params in the proper format for batch-split schedule."""
   return jax.tree.map(
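
The custom-VJP collective above is the heart of the change: the forward pass stays an ordinary BF16 reduce-scatter, while the backward pass quantizes the incoming gradient to float8_e5m2 (with per-channel scales on axis 0), all-gathers the narrow payload, and dequantizes, roughly halving backward collective bandwidth in exchange for some gradient precision. A minimal sketch of how it behaves under jax.shard_map; the mesh, axis name, and shapes here are illustrative, not taken from this commit:

    import functools

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(jax.devices(), ("expert",))  # hypothetical 1-D mesh
    n = jax.device_count()

    @functools.partial(jax.shard_map, mesh=mesh, in_specs=P("expert"), out_specs=P("expert"))
    def reduce_tokens(x):
      # Forward: plain BF16 psum_scatter along "expert".
      # Backward: the gradient is FP8-quantized, all-gathered, then dequantized.
      return quantized_psum_scatter(x, "expert", 0, True)

    x = jnp.ones((n * n, 16), jnp.bfloat16)  # n rows per shard; scatter keeps 1
    grads = jax.grad(lambda a: reduce_tokens(a).astype(jnp.float32).sum())(x)
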
@@ -164,29 +241,7 @@ def batch_split_schedule(
       routed_scaling_factor=cfg.routed_scaling_factor,
       expert_axis_name="expert",
       use_gather_mosaic_kernel=False,
-      wi_tile_size=(
-          cfg.wi_tile_fwd_batch_seq,
-          cfg.wi_tile_fwd_embed_dim,
-          cfg.wi_tile_fwd_mlp_dim,
-          cfg.wi_tile_dlhs_batch_seq,
-          cfg.wi_tile_dlhs_embed_dim,
-          cfg.wi_tile_dlhs_mlp_dim,
-          cfg.wi_tile_drhs_batch_seq,
-          cfg.wi_tile_drhs_embed_dim,
-          cfg.wi_tile_drhs_mlp_dim,
-      ),
-      wo_tile_size=(
-          cfg.wo_tile_fwd_batch_seq,
-          cfg.wo_tile_fwd_embed_dim,
-          cfg.wo_tile_fwd_mlp_dim,
-          cfg.wo_tile_dlhs_batch_seq,
-          cfg.wo_tile_dlhs_embed_dim,
-          cfg.wo_tile_dlhs_mlp_dim,
-          cfg.wo_tile_drhs_batch_seq,
-          cfg.wo_tile_drhs_embed_dim,
-          cfg.wo_tile_drhs_mlp_dim,
-      ),
-      dtype=cfg.dtype,
+      config=cfg,
   )
   xs = jax.shard_map(
       functools.partial(merge, split_factor=cfg.batch_split_factor),
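
Eighteen tile-size knobs and the dtype no longer cross this call boundary; compute() below rebuilds them from config. For orientation, each 9-tuple is consumed as three (batch_seq, embed_dim, mlp_dim) tile shapes, one per matmul pass (fwd, dlhs, drhs), judging by the config field names; a purely illustrative unpacking:

    # Illustrative only, inferred from the *_tile_{fwd,dlhs,drhs}_* field names:
    # one 9-tuple packs three (batch_seq, embed_dim, mlp_dim) tile shapes.
    def split_tiling(tile_size: tuple[int, ...]) -> dict[str, tuple[int, ...]]:
      fwd, dlhs, drhs = tile_size[:3], tile_size[3:6], tile_size[6:9]
      return {"fwd": fwd, "dlhs": dlhs, "drhs": drhs}

    assert split_tiling(tuple(range(9)))["drhs"] == (6, 7, 8)
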
@@ -581,9 +636,7 @@ def moe(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
-    wi_tile_size,
-    wo_tile_size,
-    dtype,
+    config,
 ):
   """Performs dropless MoE with tensor/expert parallelism."""
   xs, ys = list(zip(*inputs))
@@ -597,9 +650,7 @@ def moe(
           routed_scaling_factor=routed_scaling_factor,
           expert_axis_name=expert_axis_name,
           use_gather_mosaic_kernel=use_gather_mosaic_kernel,
-          wi_tile_size=wi_tile_size,
-          wo_tile_size=wo_tile_size,
-          dtype=dtype,
+          config=config,
       ),
       mesh,
   )
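
moe() forwards everything into compute(), whose grouped matmuls follow the usual ragged-dot contract: tokens arrive pre-sorted by expert, and group_sizes[i] consecutive rows of the activations multiply expert i's kernel. A self-contained illustration with stock jax.lax.ragged_dot; the megablox.gmm / tokamax.ragged_dot calls in the next hunk are TPU-kernel implementations of this same contract:

    import jax
    import jax.numpy as jnp

    lhs = jnp.ones((6, 4), jnp.bfloat16)             # 6 routed tokens, embed = 4
    rhs = jnp.ones((3, 4, 8), jnp.bfloat16)          # 3 experts, embed 4 -> mlp 8
    group_sizes = jnp.array([1, 2, 3], jnp.int32)    # rows per expert; sums to 6
    out = jax.lax.ragged_dot(lhs, rhs, group_sizes)  # -> (6, 8)
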
@@ -691,20 +742,133 @@ def unroute(
   return jax.lax.psum_scatter(x, expert_axis_name, scatter_dimension=0, tiled=True)
 
 
-def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size, dtype):
+def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh):
   """Processes routed tokens through the MLP."""
-  gmm_fn = functools.partial(
-      megablox.gmm,
-      group_sizes=group_sizes,
-      preferred_element_type=dtype,
+
+  def gmm(
+      inputs,
+      kernel,
+      tiling,
+      group_sizes,
+      preferred_element_type,
+      weight_gather_axes,
+      input_buffer_count,
+      combine_scopes,
+  ):
+    if config.use_qwix_quantization:
+      output = megablox.gmm(
+          lhs=inputs,
+          rhs=kernel,
+          group_sizes=group_sizes,
+          preferred_element_type=preferred_element_type,
+          tiling=tiling,
+          use_qwix_quantization=config.use_qwix_quantization,
+          use_tokamax_backend=config.use_tokamax_gmm,
+          weight_gather_axes=weight_gather_axes,
+          input_buffer_count=input_buffer_count,
+          combine_scopes=combine_scopes,
+      )
+    else:
+      output = tokamax.ragged_dot(
+          lhs=inputs,
+          rhs=kernel,
+          group_sizes=group_sizes,
+          precision=jax.lax.Precision.DEFAULT,
+          preferred_element_type=preferred_element_type,
+          implementation="mosaic",
+      )
+    return output
+
+  gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)
+  wi_gather_axes = []
+  wo_gather_axes = []
+
+  wi_tile_size = (
+      config.wi_tile_fwd_batch_seq,
+      config.wi_tile_fwd_embed_dim,
+      config.wi_tile_fwd_mlp_dim,
+      config.wi_tile_dlhs_batch_seq,
+      config.wi_tile_dlhs_embed_dim,
+      config.wi_tile_dlhs_mlp_dim,
+      config.wi_tile_drhs_batch_seq,
+      config.wi_tile_drhs_embed_dim,
+      config.wi_tile_drhs_mlp_dim,
+  )
+  wo_tile_size = (
+      config.wo_tile_fwd_batch_seq,
+      config.wo_tile_fwd_embed_dim,
+      config.wo_tile_fwd_mlp_dim,
+      config.wo_tile_dlhs_batch_seq,
+      config.wo_tile_dlhs_embed_dim,
+      config.wo_tile_dlhs_mlp_dim,
+      config.wo_tile_drhs_batch_seq,
+      config.wo_tile_drhs_embed_dim,
+      config.wo_tile_drhs_mlp_dim,
+  )
+  wi_input_buffer_count = (
+      config.wi_tile_fwd_buffer_count,
+      config.wi_tile_dlhs_buffer_count,
+      config.wi_tile_drhs_buffer_count,
+  )
+  wo_input_buffer_count = (
+      config.wo_tile_fwd_buffer_count,
+      config.wo_tile_dlhs_buffer_count,
+      config.wo_tile_drhs_buffer_count,
+  )
+
+  wi_combine_scopes = config.wi_combine_scopes
+  wo_combine_scopes = config.wo_combine_scopes
+  if config.use_qwix_quantization:
+    gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
+    w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
+    wo_pspec = nn.logical_to_mesh_axes(linear_pspec)
+    ignored_axes = ("expert", "tensor", "tensor_transpose")
+
+    def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
+      if pspec_dim_axes is None:
+        return []
+      axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
+      active = []
+      for ax in axes:
+        if ax and ax not in ignored_axes and mesh.shape.get(ax, 1) > 1:
+          active.append((ax, tensor_dim_index))
+      return active
+
+    wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
+    wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))
+
+    wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
+    wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))
+
+  layer_w0 = gmm_fn(
+      x,
+      w0,
+      tiling=wi_tile_size,
+      weight_gather_axes=wi_gather_axes,
+      input_buffer_count=wi_input_buffer_count,
+      combine_scopes=wi_combine_scopes,
+  )
+  layer_w1 = gmm_fn(
+      x,
+      w1,
+      tiling=wi_tile_size,
+      weight_gather_axes=wi_gather_axes,
+      input_buffer_count=wi_input_buffer_count,
+      combine_scopes=wi_combine_scopes,
   )
-  layer_w0 = gmm_fn(x, w0, tiling=wi_tile_size)
-  layer_w1 = gmm_fn(x, w1, tiling=wi_tile_size)
   layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
   layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
   intermediate_layer = jax.nn.silu(layer_w0) * layer_w1
   intermediate_layer *= weights[:, None]
-  return gmm_fn(intermediate_layer, wo, tiling=wo_tile_size)
+  layer_wo = gmm_fn(
+      intermediate_layer,
+      wo,
+      tiling=wo_tile_size,
+      weight_gather_axes=wo_gather_axes,
+      input_buffer_count=wo_input_buffer_count,
+      combine_scopes=wo_combine_scopes,
+  )
+  return layer_wo
 
 
 def route_compute_unroute(
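
get_active_sharding_axes is the subtle piece of the qwix path: for one dimension of a kernel PartitionSpec it keeps only real mesh axes (size > 1) that are not already handled by the expert/tensor layout, pairing each with the kernel dimension it shards, so the quantized gmm knows which axes to gather weights over. A standalone sketch of that selection logic with a stand-in mesh-shape mapping (the real one comes from the Mesh bound inside shard_map):

    mesh_shape = {"data": 4, "fsdp": 2, "tensor": 1, "expert": 8}  # stand-in values
    ignored_axes = ("expert", "tensor", "tensor_transpose")

    def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
      if pspec_dim_axes is None:
        return []
      axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
      active = []
      for ax in axes:
        if ax and ax not in ignored_axes and mesh_shape.get(ax, 1) > 1:
          active.append((ax, tensor_dim_index))
      return active

    assert get_active_sharding_axes(("fsdp", "tensor"), 0) == [("fsdp", 0)]
    assert get_active_sharding_axes("expert", 2) == []  # ignored axis
    assert get_active_sharding_axes(None, 1) == []      # unsharded dimension
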
@@ -716,9 +880,8 @@ def route_compute_unroute(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
-    wi_tile_size,
-    wo_tile_size,
-    dtype,
+    config,
+    mesh,
 ):
   """Routes, processes, and unroutes activations."""
   orig_shape = xs[0].shape
@@ -760,9 +923,8 @@ def compute_fn(inputs):
         routed_wo,
         group_sizes,
         weights,
-        wi_tile_size=wi_tile_size,
-        wo_tile_size=wo_tile_size,
-        dtype=dtype,
+        config=config,
+        mesh=mesh,
     )
     return x, y, selected_experts
 
@@ -792,21 +954,21 @@ def process_activations(
     routed_scaling_factor,
     expert_axis_name,
     use_gather_mosaic_kernel,
-    wi_tile_size,
-    wo_tile_size,
-    dtype,
+    config,
 ):
   """Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
   activation_pspec = jax.sharding.PartitionSpec(
       ("data", "fsdp", "fsdp_transpose", "expert", "context"),
       None,
       None,
   )
-  gating_pspec, linear_pspec = (
-      jax.sharding.PartitionSpec(None, None, expert_axis_name),
-      jax.sharding.PartitionSpec(None, expert_axis_name, None),
-  )
-
+  if config.use_qwix_quantization:
+    gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
+    gating_pspec = nn.logical_to_mesh_axes(gating_pspec)
+    linear_pspec = nn.logical_to_mesh_axes(linear_pspec)
+  else:
+    gating_pspec = jax.sharding.PartitionSpec(None, None, expert_axis_name)
+    linear_pspec = jax.sharding.PartitionSpec(None, expert_axis_name, None)
   return jax.shard_map(
       functools.partial(
          route_compute_unroute,
@@ -815,9 +977,8 @@ def process_activations(
          routed_scaling_factor=routed_scaling_factor,
          expert_axis_name=expert_axis_name,
          use_gather_mosaic_kernel=use_gather_mosaic_kernel,
-         wi_tile_size=wi_tile_size,
-         wo_tile_size=wo_tile_size,
-         dtype=dtype,
+         config=config,
+         mesh=mesh,
       ),
       mesh=mesh,
       in_specs=(
@@ -841,4 +1002,4 @@
       ),
       out_specs=activation_pspec,
       check_vma=False,
-  )([x.astype(dtype) for x in xs], weights)
+  )([x.astype(config.dtype) for x in xs], weights)
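
For reference, the else branch keeps the pre-existing weight sharding. Assuming MaxText's routed kernel layouts of wi: (num_experts, embed, mlp) and wo: (num_experts, mlp, embed) (an assumption about this file, not stated in the hunk), both specs split the mlp dimension over the "expert" mesh axis, i.e. tensor parallelism whose partial sums are later reduced by unroute()'s psum_scatter:

    from jax.sharding import PartitionSpec as P

    gating_pspec = P(None, None, "expert")  # wi: (experts, embed, mlp) -> shard mlp
    linear_pspec = P(None, "expert", None)  # wo: (experts, mlp, embed) -> shard mlp
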
