Skip to content

Commit 519ff77

Browse files
Kevin Wang authored and Google-ML-Automation committed
Refactor execution of batch-split deepseek sparse layers in Decoder to use pure JAX.
PiperOrigin-RevId: 875276104
1 parent 46ec3af commit 519ff77

3 files changed

Lines changed: 126 additions & 29 deletions

File tree

src/maxtext/layers/decoders.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from maxtext.layers.quantizations import AqtQuantization as Quant
4242
from maxtext.models import (
4343
deepseek,
44+
deepseek_batchsplit,
4445
gemma,
4546
gemma2,
4647
gemma3,
@@ -865,15 +866,32 @@ def __call__(
865866
moe_layer = RemattedBlockLayers[1]
866867
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
867868
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
868-
y, _ = self.scan_decoder_layers(
869-
cfg,
870-
moe_layer,
871-
num_moe_layers,
872-
"moe_layers",
873-
mesh,
874-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
875-
model_mode=model_mode,
876-
)(y, *broadcast_args)
869+
870+
# If batch-split schedule is used and initialization is complete,
871+
# as detected by immutable params, use deepseek_batchsplit custom
872+
# scan with initialized parameters.
873+
if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
874+
y = deepseek_batchsplit.scan_batch_split_layers(
875+
y,
876+
self.variables["params"]["moe_layers"],
877+
decoder_positions,
878+
decoder_segment_ids,
879+
model_mode=model_mode,
880+
mesh=mesh,
881+
quant=self.quant,
882+
cfg=cfg,
883+
policy=policy,
884+
)
885+
else:
886+
y, _ = self.scan_decoder_layers(
887+
cfg,
888+
moe_layer,
889+
num_moe_layers,
890+
"moe_layers",
891+
mesh,
892+
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
893+
model_mode=model_mode,
894+
)(y, *broadcast_args)
877895
elif cfg.decoder_block == DecoderBlockType.GEMMA3:
878896
y = self._apply_gemma3_scanned_blocks(
879897
y,

src/maxtext/models/deepseek.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# pylint: disable=arguments-differ
1717
# pylint: disable=no-name-in-module
1818

19+
import functools
1920
from typing import Optional
2021

2122
from flax import nnx
23+
import jax
2224
from jax.ad_checkpoint import checkpoint_name
2325
import jax.numpy as jnp
2426
from jax.sharding import Mesh
@@ -58,7 +60,6 @@ def __init__(
5860
rngs: nnx.Rngs,
5961
quant: Optional[quantizations.AqtQuantization] = None,
6062
) -> None:
61-
6263
self.config = config
6364
self.model_mode = model_mode
6465
self.mesh = mesh
@@ -350,7 +351,6 @@ def __init__(
350351
rngs: nnx.Rngs,
351352
quant: Optional[quantizations.AqtQuantization] = None,
352353
) -> None:
353-
354354
super().__init__(config, model_mode, mesh, rngs, quant)
355355
self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
356356
config=self.config,
@@ -380,18 +380,48 @@ def __call__(
380380
if isinstance(inputs, tuple):
381381
inputs = inputs[0]
382382

383-
# If using batch split schedule, call the batch split version of the layer.
383+
# This code should only be traced during initialization when using
384+
# batch-split schedule. It is never run during model execution, since
385+
# `Decoder` directly calls `batch_split_schedule` during execution.
386+
# That is also why we can split/merge activations here as well as
387+
# in `Decoder`, since they will never be executed together.
384388
if self.config.use_batch_split_schedule:
389+
activation_pspec = jax.sharding.PartitionSpec(
390+
("data", "fsdp", "fsdp_transpose", "expert", "context"),
391+
None,
392+
None,
393+
)
394+
inputs = jax.shard_map(
395+
functools.partial(
396+
deepseek_batchsplit.split,
397+
split_factor=self.config.batch_split_factor,
398+
),
399+
mesh=self.mesh,
400+
in_specs=activation_pspec,
401+
out_specs=[activation_pspec] * self.config.batch_split_factor,
402+
)(inputs)
403+
dpos = deepseek_batchsplit.split(decoder_positions, self.config.batch_split_factor)
404+
dseg = deepseek_batchsplit.split(decoder_segment_ids, self.config.batch_split_factor)
405+
weights = deepseek_batchsplit.fetch_weights(nnx.to_pure_dict(nnx.state(self, nnx.Param)), self.config.dtype)
385406
outputs = deepseek_batchsplit.batch_split_schedule(
386407
inputs,
387-
nnx.to_pure_dict(nnx.state(self, nnx.Param)),
388-
decoder_positions,
389-
decoder_segment_ids,
408+
weights,
409+
dpos,
410+
dseg,
390411
model_mode=model_mode,
391412
mesh=self.mesh,
392413
quant=self.quant,
393414
cfg=self.config,
394415
)
416+
outputs = jax.shard_map(
417+
functools.partial(
418+
deepseek_batchsplit.merge,
419+
split_factor=self.config.batch_split_factor,
420+
),
421+
mesh=self.mesh,
422+
in_specs=([activation_pspec] * self.config.batch_split_factor,),
423+
out_specs=activation_pspec,
424+
)(outputs)
395425
return outputs, None
396426

397427
x = self.with_logical_constraint(inputs)

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def _q_psum_scatter_bwd(
106106
def fetch_weights(params, dtype):
107107
"""Fetches weights from params in the proper format for batch-split schedule."""
108108
return jax.tree.map(
109-
lambda x: jnp.asarray(x[...], dtype),
109+
# If x is a LogicallyPartitioned array, then x.value is the underlying
110+
# array. If not, use the array directly.
111+
lambda x: jnp.asarray(getattr(x, "value", x)[...], dtype),
110112
(
111113
(
112114
(
@@ -165,7 +167,7 @@ def merge(x, split_factor=2):
165167
return jnp.reshape(x, (-1,) + x.shape[2:])
166168

167169

168-
def batch_split_schedule(
170+
def scan_batch_split_layers(
169171
inputs,
170172
params,
171173
positions,
@@ -175,22 +177,75 @@ def batch_split_schedule(
175177
mesh,
176178
quant,
177179
cfg,
180+
policy,
178181
):
179-
"""Applies the DeepSeek MoE layer with batch-split schedule."""
182+
"""Scans the layers with batch-split schedule."""
183+
184+
def batch_split_scan_fn(inputs, weights, dpos, dseg):
185+
xs = batch_split_schedule(
186+
inputs,
187+
weights,
188+
dpos,
189+
dseg,
190+
model_mode=model_mode,
191+
mesh=mesh,
192+
quant=quant,
193+
cfg=cfg,
194+
)
195+
return xs, None
196+
197+
batch_split_scan_fn_checkpointed = jax.checkpoint(
198+
batch_split_scan_fn,
199+
# No need to prevent CSE inside scan.
200+
prevent_cse=False,
201+
policy=policy,
202+
)
203+
weights = fetch_weights(params, cfg.dtype)
204+
# `jax.lax.scan` expects the leading dimension of weights to be the scan
205+
# dimension, but the weights are initialized/loaded with the param scan
206+
# axis as the scan dimension, so swap the axes.
207+
weights = jax.tree.map(lambda x: jnp.swapaxes(x, 0, cfg.param_scan_axis), weights)
208+
180209
activation_pspec = jax.sharding.PartitionSpec(
181210
("data", "fsdp", "fsdp_transpose", "expert", "context"),
182211
None,
183212
None,
184213
)
185-
xs = jax.shard_map(
214+
inputs = jax.shard_map(
186215
functools.partial(split, split_factor=cfg.batch_split_factor),
187216
mesh=mesh,
188217
in_specs=activation_pspec,
189218
out_specs=[activation_pspec] * cfg.batch_split_factor,
190219
)(inputs)
191220
dpos = split(positions, split_factor=cfg.batch_split_factor)
192221
dseg = split(segment_ids, split_factor=cfg.batch_split_factor)
193-
xs = [with_data_parallel_constraint(x, mesh) for x in xs]
222+
outputs, _ = jax.lax.scan(
223+
functools.partial(batch_split_scan_fn_checkpointed, dpos=dpos, dseg=dseg),
224+
inputs,
225+
weights,
226+
)
227+
outputs = jax.shard_map(
228+
functools.partial(merge, split_factor=cfg.batch_split_factor),
229+
mesh=mesh,
230+
in_specs=([activation_pspec] * cfg.batch_split_factor,),
231+
out_specs=activation_pspec,
232+
)(outputs)
233+
return outputs
234+
235+
236+
def batch_split_schedule(
237+
inputs,
238+
weights,
239+
positions,
240+
segment_ids,
241+
*,
242+
model_mode,
243+
mesh,
244+
quant,
245+
cfg,
246+
):
247+
"""Applies the DeepSeek MoE layer with batch-split schedule."""
248+
xs = [with_data_parallel_constraint(x, mesh) for x in inputs]
194249
xs = jax.ad_checkpoint.checkpoint_name(xs, "decoder_layer_input")
195250

196251
attn_op = attention_op.AttentionOp(
@@ -207,12 +262,12 @@ def batch_split_schedule(
207262
dtype=cfg.dtype,
208263
attention_type=cfg.attention_type,
209264
)
210-
norm_mla_ws, moe_ws = fetch_weights(params, cfg.dtype)
265+
norm_mla_ws, moe_ws = weights
211266
xs = mla_with_norms(
212267
xs,
213268
norm_mla_ws,
214-
dpos,
215-
dseg,
269+
positions,
270+
segment_ids,
216271
mesh=mesh,
217272
model_mode=model_mode,
218273
attn_op=attn_op,
@@ -242,12 +297,6 @@ def batch_split_schedule(
242297
use_gather_mosaic_kernel=False,
243298
config=cfg,
244299
)
245-
xs = jax.shard_map(
246-
functools.partial(merge, split_factor=cfg.batch_split_factor),
247-
mesh=mesh,
248-
in_specs=([activation_pspec] * cfg.batch_split_factor,),
249-
out_specs=activation_pspec,
250-
)(xs)
251300
return xs
252301

253302

0 commit comments

Comments (0)