Skip to content

Commit f62ee44

Browse files
Kevin Wang (Google-ML-Automation)
authored and committed
Make batch split factor configurable.
PiperOrigin-RevId: 866619829
1 parent 00eb74e commit f62ee44

3 files changed

Lines changed: 16 additions & 8 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ topk_routing_group: -1 # number of top groups to route inputs. For EP,
240240
# Splits the batch to allow for better scheduling when using expert parallelism by overlapping the
241241
# all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
242242
use_batch_split_schedule: False # a flag if splitting batch into micro-batches to hide communications that yields performance benefits.
243+
batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True.
243244

244245
# For complex architectures like llama4 there are repeated sets of
245246
# inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,10 @@ class DeepSeekMoE(BaseModel):
692692
False,
693693
description="Whether to split batch into micro-batches to hide communications that yields performance benefits.",
694694
)
695+
batch_split_factor: int = Field(
696+
1,
697+
description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.",
698+
)
695699

696700

697701
class Qwen3Next(BaseModel):

src/MaxText/layers/deepseek_batchsplit.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def fetch_weights(params, dtype):
7171
@jax.named_scope("deepseek_batchsplit_split")
7272
def split(x, split_factor=2):
7373
"""Splits the input into `split_factor` parts along the batch dimension."""
74-
74+
if split_factor == 1:
75+
return [x]
7576
if x is None:
7677
return [None] * split_factor
7778
else:
@@ -80,8 +81,10 @@ def split(x, split_factor=2):
8081

8182

8283
@jax.named_scope("deepseek_batchsplit_merge")
83-
def merge(x):
84+
def merge(x, split_factor=2):
8485
"""Merges the input microbatches back into a single tensor."""
86+
if split_factor == 1:
87+
return x[0]
8588
x = jnp.stack(x, axis=1)
8689
return jnp.reshape(x, (-1,) + x.shape[2:])
8790

@@ -104,13 +107,13 @@ def batch_split_schedule(
104107
None,
105108
)
106109
xs = jax.shard_map(
107-
split,
110+
functools.partial(split, split_factor=cfg.batch_split_factor),
108111
mesh=mesh,
109112
in_specs=activation_pspec,
110-
out_specs=[activation_pspec, activation_pspec],
113+
out_specs=[activation_pspec] * cfg.batch_split_factor,
111114
)(inputs)
112-
dpos = split(positions)
113-
dseg = split(segment_ids)
115+
dpos = split(positions, split_factor=cfg.batch_split_factor)
116+
dseg = split(segment_ids, split_factor=cfg.batch_split_factor)
114117
xs = [with_data_parallel_constraint(x, mesh) for x in xs]
115118
xs = jax.ad_checkpoint.checkpoint_name(xs, "decoder_layer_input")
116119

@@ -186,9 +189,9 @@ def batch_split_schedule(
186189
dtype=cfg.dtype,
187190
)
188191
xs = jax.shard_map(
189-
merge,
192+
functools.partial(merge, split_factor=cfg.batch_split_factor),
190193
mesh=mesh,
191-
in_specs=([activation_pspec, activation_pspec],),
194+
in_specs=([activation_pspec] * cfg.batch_split_factor,),
192195
out_specs=activation_pspec,
193196
)(xs)
194197
return xs

0 commit comments

Comments (0)