Commit 351eebc

support attention data parallelism
1 parent 05bf24c · commit 351eebc

11 files changed: 88 additions & 52 deletions

src/MaxText/common_types.py

Lines changed: 5 additions & 0 deletions

@@ -32,6 +32,10 @@

 BATCH = "activation_batch"
 BATCH_NO_EXP = "activation_batch_no_exp"
+
+ATTN_LENGTH = "activation_attn_length"
+ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
+
 LENGTH = "activation_length"
 LENGTH_NO_EXP = "activation_length_no_exp"
 PREFILL_LENGTH = "prefill_activation_length"
@@ -40,6 +44,7 @@
 Q_LORA_UP_PROJ = "q_lora_up_proj"
 KV_LENGTH = "activation_kv_length"
 KV_LORA_UP_PROJ = "kv_lora_up_proj"
+ATTN_EMBED = "activation_attn_embed"
 EMBED = "activation_embed"
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
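
The added names follow the existing pattern: a logical axis name only becomes a concrete sharding once logical_axis_rules maps it to mesh axes. A minimal sketch of that resolution using Flax's logical-axis helpers; the rule set below is illustrative, not the exact rules added in this commit.

import flax.linen as nn

# Illustrative rules; the real mappings for these names live in base.yml / vllm.yml.
rules = (
    ("activation_batch", ("data",)),
    ("activation_attn_length_no_exp", ("context",)),
    ("activation_attn_embed", ("tensor",)),
)

spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_attn_length_no_exp", "activation_attn_embed"), rules
)
print(spec)  # roughly PartitionSpec(('data',), ('context',), ('tensor',))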

src/MaxText/configs/base.yml

Lines changed: 5 additions & 0 deletions

@@ -393,6 +393,10 @@ logical_axis_rules: [
 ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
 ['activation_length', ['sequence', 'context', 'expert']],
 ['activation_length', ['context', 'expert']],
+['activation_attn_length', ['sequence', 'context', 'expert']],
+['activation_attn_length', ['context', 'expert']],
+['activation_attn_length_no_exp', ['sequence', 'context']],
+['activation_attn_length_no_exp', ['context']],
 ['activation_length_no_exp', ['sequence', 'context']],
 ['activation_length_no_exp', ['context']],
 ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
@@ -401,6 +405,7 @@ logical_axis_rules: [
 ['prefill_activation_length', ['sequence', 'context']],
 ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
 ['activation_kv_length', []],
+['activation_attn_embed', ['tensor', 'tensor_transpose']],
 ['activation_embed', ['tensor', 'tensor_transpose']],
 ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
 ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],

src/MaxText/configs/types.py

Lines changed: 2 additions & 0 deletions

@@ -2220,6 +2220,7 @@
       "model": self.ici_tensor_parallelism,
       "expert": self.ici_expert_parallelism,
       "autoregressive": self.ici_autoregressive_parallelism,
+      "attn_dp": 1,  # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
     }
     self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

@@ -2237,6 +2238,7 @@
       "model": self.dcn_tensor_parallelism,
       "expert": self.dcn_expert_parallelism,
       "autoregressive": self.dcn_autoregressive_parallelism,
+      "attn_dp": 1,  # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
     }
     self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
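
For illustration only: the comment above says the real attn_dp degree is derived on the vLLM side from tensor parallelism and num_kv_heads. One plausible reading, as a hypothetical helper (not code from this commit or from vLLM), is that whatever tensor-parallel degree cannot be used to split KV heads is folded into attention data parallelism.

def illustrative_attn_dp(tensor_parallelism: int, num_kv_heads: int) -> int:
  """Hypothetical: excess TP beyond the number of KV heads becomes attention DP."""
  if tensor_parallelism <= num_kv_heads:
    return 1
  assert tensor_parallelism % num_kv_heads == 0
  return tensor_parallelism // num_kv_heads

print(illustrative_attn_dp(tensor_parallelism=16, num_kv_heads=8))  # 2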

src/MaxText/configs/vllm.yml

Lines changed: 17 additions & 11 deletions

@@ -25,43 +25,49 @@ weight_dtype: bfloat16


 # -------------- Logical Axis Rules --------------
-mesh_axes: ['data', 'model', 'expert']
+mesh_axes: ['data', 'attn_dp', 'model', 'expert']
 logical_axis_rules: [
 ['activation_batch', ['expert']],
 ['activation_batch_no_exp', []],
 ['activation_embed_and_logits_batch', ['expert']],
 ['activation_embed_and_logits_batch_sequence', ['expert']],
 ['activation_heads', ['model']],
 ['activation_kv_heads', ['model']],
+['activation_attn_length', ['expert']],
+['activation_attn_length_no_exp', []],
 ['activation_length', ['data', 'expert']],
-['activation_q_length', ['data', 'expert']],
-['activation_embed', ['model']],
-['activation_mlp', ['model']],
+['activation_length_no_exp', 'data'],
+['activation_q_length', ['expert']],
+['activation_attn_embed', 'model'],
+['activation_embed', ['model', 'attn_dp']],
+['activation_mlp', ['model', 'attn_dp']],
 ['activation_kv', ['model']],
 ['activation_prefill_kv_batch', ['expert']],
 ['activation_kv_batch', ['expert']],
 ['activation_kv_batch_no_exp', []],
 ['activation_kv_head_dim', ['model']],
-['activation_vocab', ['model']],
-['activation_embed', ['model']],
+['activation_vocab', ['model', 'attn_dp']],
+['activation_norm_length', []],
 ['activation_exp', ['expert']],
 ['decode_batch', ['expert']],
-['mlp', ['model']],
-['mlp_no_fsdp', ['model']],
-['vocab', ['model']],
+['decode_length', []],
+['mlp', ['model', 'attn_dp']],
+['mlp_no_fsdp', ['model', 'attn_dp']],
+['vocab', ['model', 'attn_dp']],
 ['heads', ['model']],
 ['q_heads', ['model']],
 ['kv_heads', ['model']],
 ['kv_head_dim', []],
 ['kv', []],
 ['embed', ['expert']],
+['embed_tensor_transpose', ['attn_dp', 'model']],
 ['embed_no_exp', []],
 ['q_lora', ['expert']],
 ['kv_lora', ['expert']],
-['norm', ['model']],
+['norm', []],
 ['cache_heads', ['model']],
 ['exp', ['expert']],
 ['paged_kv_heads', ['model']],
 ]
-data_sharding: [['data', 'model', 'expert']]
+data_sharding: [['data', 'attn_dp', 'model', 'expert']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
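
A minimal sketch of what the new mesh layout looks like at the JAX level, under assumed axis sizes (data=1, attn_dp=2, model=4, expert=1, i.e. 8 devices); the batch dimension of the input is split over all four axes, matching data_sharding above.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative sizes only; the real sizes come from the ici/dcn parallelism config.
devices = np.asarray(jax.devices()).reshape(1, 2, 4, 1)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "model", "expert"))

# data_sharding: [['data', 'attn_dp', 'model', 'expert']] -> one array dim split over all axes.
batch_spec = PartitionSpec(("data", "attn_dp", "model", "expert"))
x = jax.device_put(np.zeros((16, 128), np.float32), NamedSharding(mesh, batch_spec))
print(x.sharding)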

src/MaxText/layers/attentions.py

Lines changed: 27 additions & 27 deletions

@@ -34,8 +34,8 @@
 D_KV,
 AxisNames,
 AxisIdxes,
-LENGTH,
-LENGTH_NO_EXP,
+ATTN_LENGTH,
+ATTN_LENGTH_NO_EXP,
 DType,
 Config,
 Array,
@@ -46,7 +46,7 @@
 KV_HEAD_DIM,
 KV_BATCH,
 KV_BATCH_NO_EXP,
-EMBED,
+ATTN_EMBED,
 MODEL_MODE_AUTOREGRESSIVE,
 MODEL_MODE_TRAIN,
 MODEL_MODE_PREFILL,
@@ -141,18 +141,18 @@ def attention_as_linen(
 prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
-prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
-decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
+query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
+ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED),
+out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
+ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV),
+prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
+decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
 prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
 decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
 prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
@@ -300,18 +300,18 @@ def __init__(
 prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
 prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
-input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
-ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
-out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
-ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
-prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
-decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
+query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
+key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
+value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
+ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
+ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED),
+out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
+ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV),
+prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
+decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
 prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
 decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
 prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
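
These defaults only take effect where the activations are actually constrained with those names. A sketch of how such a tuple is typically applied inside a Flax module (the wrapper function and shapes here are illustrative, not code from this commit):

import flax.linen as nn
import jax.numpy as jnp

def constrain_attention_input(x: jnp.ndarray) -> jnp.ndarray:
  # Each entry names the logical axis of the corresponding array dimension;
  # logical_axis_rules decides which mesh axes (if any) it maps to.
  return nn.with_logical_constraint(
      x, ("activation_batch", "activation_attn_length_no_exp", "activation_attn_embed")
  )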

src/MaxText/layers/moe.py

Lines changed: 6 additions & 1 deletion

@@ -355,7 +355,7 @@ def __init__(

     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
-      self._tensor_parallelism_name = "model"
+      self._tensor_parallelism_name = ("model", "attn_dp")
     else:
       self._tensor_parallelism_name = "tensor"

@@ -459,6 +459,11 @@ def get_expert_parallelism_size(self):
     return self.mesh.shape.get("expert", 1)

   def get_tensor_parallelism_size(self):
+    if isinstance(self._tensor_parallelism_name, tuple):
+      size = 1
+      for axis in self._tensor_parallelism_name:
+        size *= self.mesh.shape.get(axis, 1)
+      return size
     return self.mesh.shape.get(self._tensor_parallelism_name, 1)

   def get_tensor_transpose_parallelism_size(self):
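
What the new tuple branch computes is just the product of the mesh sizes of the listed axes, so when attention runs under vLLM the reported tensor-parallel degree spans both 'model' and 'attn_dp'. A standalone illustration with assumed axis sizes (the dict stands in for self.mesh.shape):

import math

mesh_shape = {"data": 1, "attn_dp": 2, "model": 4, "expert": 1}  # stand-in for self.mesh.shape
axes = ("model", "attn_dp")
print(math.prod(mesh_shape.get(a, 1) for a in axes))  # 8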

src/MaxText/maxtext_utils.py

Lines changed: 4 additions & 4 deletions

@@ -1175,12 +1175,12 @@ def schedule(step):
   return optax.join_schedules(pieces, boundaries)


-def print_state_mesh_shardings_params(state, state_sharding, mesh):
+def print_shardings_params(params, params_sharding, mesh):
   """Print state shardings."""
-  leaves_params, _ = jax.tree_util.tree_flatten_with_path(state.params)
-  leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params)
+  leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
+  leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
   for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
-    path_str = "/".join(str(p.key) for p in path)
+    path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
     shape = jax.typeof(leaf_val)
     pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
     max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
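
The hasattr check appears to be there because tree_flatten_with_path can yield path entries other than DictKey (which exposes .key), e.g. GetAttrKey entries (which expose .name) when the flattened tree is not a plain nested dict. A tiny self-contained illustration of the dict case:

import jax

leaves, _ = jax.tree_util.tree_flatten_with_path({"decoder": {"kernel": 1.0}})
for path, _ in leaves:
  print("/".join(str(p.key if hasattr(p, "key") else p.name) for p in path))  # decoder/kernel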

src/MaxText/model_creation_utils.py

Lines changed: 5 additions & 1 deletion

@@ -23,6 +23,7 @@
 import jax
 from jax.sharding import Mesh, AxisType
 from MaxText import maxtext_utils
+from MaxText import max_utils
 from MaxText import pyconfig
 from MaxText.layers import quantizations
 from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
@@ -153,7 +154,10 @@ def create_sharded_state():
   with nn.logical_axis_rules(config.logical_axis_rules):
     sharded_state = create_sharded_state()
   model = nnx.merge(graphdef, sharded_state)
-
+  # print weights sharding info under debug sharding mode
+  if config.debug_sharding:
+    max_utils.print_non_trivial_mesh_axis(model.mesh)
+    maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh)
   if config.load_parameters_path:
     try:
       ckptr = ocp.Checkpointer(

src/MaxText/train_compile.py

Lines changed: 1 addition & 1 deletion

@@ -228,7 +228,7 @@ def main(argv: Sequence[str]) -> None:
   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(topology_mesh)
-    maxtext_utils.print_state_mesh_shardings_params(shaped_train_args[0], state_mesh_shardings, topology_mesh)
+    maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh)

   # Compile
   print("Jitting and compiling train step...", flush=True)

src/MaxText/train_utils.py

Lines changed: 1 addition & 1 deletion

@@ -218,7 +218,7 @@ def setup_train_loop(config, recorder, devices=None):
   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(model.mesh)
-    maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
+    maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh)

   if config.use_dpo:
     abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
