Skip to content

Commit 3c8e0fa

Browse files
Merge pull request #2955 from AI-Hypercomputer:mohit/attn_dp
PiperOrigin-RevId: 859195426
2 parents aae7aea + 351eebc commit 3c8e0fa

11 files changed

Lines changed: 88 additions & 52 deletions

File tree

src/MaxText/common_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232

3333
BATCH = "activation_batch"
3434
BATCH_NO_EXP = "activation_batch_no_exp"
35+
36+
ATTN_LENGTH = "activation_attn_length"
37+
ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
38+
3539
LENGTH = "activation_length"
3640
LENGTH_NO_EXP = "activation_length_no_exp"
3741
PREFILL_LENGTH = "prefill_activation_length"
@@ -40,6 +44,7 @@
4044
Q_LORA_UP_PROJ = "q_lora_up_proj"
4145
KV_LENGTH = "activation_kv_length"
4246
KV_LORA_UP_PROJ = "kv_lora_up_proj"
47+
ATTN_EMBED = "activation_attn_embed"
4348
EMBED = "activation_embed"
4449
HEAD = "activation_heads"
4550
PREFILL_KV_BATCH = "activation_prefill_kv_batch"

src/MaxText/configs/base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ logical_axis_rules: [
393393
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
394394
['activation_length', ['sequence', 'context', 'expert']],
395395
['activation_length', ['context', 'expert']],
396+
['activation_attn_length', ['sequence', 'context', 'expert']],
397+
['activation_attn_length', ['context', 'expert']],
398+
['activation_attn_length_no_exp', ['sequence', 'context']],
399+
['activation_attn_length_no_exp', ['context']],
396400
['activation_length_no_exp', ['sequence', 'context']],
397401
['activation_length_no_exp', ['context']],
398402
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
@@ -401,6 +405,7 @@ logical_axis_rules: [
401405
['prefill_activation_length', ['sequence', 'context']],
402406
['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
403407
['activation_kv_length', []],
408+
['activation_attn_embed', ['tensor', 'tensor_transpose']],
404409
['activation_embed', ['tensor', 'tensor_transpose']],
405410
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
406411
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],

src/MaxText/configs/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,6 +2278,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22782278
"model": self.ici_tensor_parallelism,
22792279
"expert": self.ici_expert_parallelism,
22802280
"autoregressive": self.ici_autoregressive_parallelism,
2281+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
22812282
}
22822283
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
22832284

@@ -2295,6 +2296,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22952296
"model": self.dcn_tensor_parallelism,
22962297
"expert": self.dcn_expert_parallelism,
22972298
"autoregressive": self.dcn_autoregressive_parallelism,
2299+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
22982300
}
22992301
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
23002302

src/MaxText/configs/vllm.yml

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,43 +25,49 @@ weight_dtype: bfloat16
2525

2626

2727
# -------------- Logical Axis Rules --------------
28-
mesh_axes: ['data', 'model', 'expert']
28+
mesh_axes: ['data', 'attn_dp', 'model', 'expert']
2929
logical_axis_rules: [
3030
['activation_batch', ['expert']],
3131
['activation_batch_no_exp', []],
3232
['activation_embed_and_logits_batch', ['expert']],
3333
['activation_embed_and_logits_batch_sequence', ['expert']],
3434
['activation_heads', ['model']],
3535
['activation_kv_heads', ['model']],
36+
['activation_attn_length', ['expert']],
37+
['activation_attn_length_no_exp', []],
3638
['activation_length', ['data', 'expert']],
37-
['activation_q_length', ['data', 'expert']],
38-
['activation_embed', ['model']],
39-
['activation_mlp', ['model']],
39+
['activation_length_no_exp', 'data'],
40+
['activation_q_length', ['expert']],
41+
['activation_attn_embed', 'model'],
42+
['activation_embed', ['model', 'attn_dp']],
43+
['activation_mlp', ['model', 'attn_dp']],
4044
['activation_kv', ['model']],
4145
['activation_prefill_kv_batch', ['expert']],
4246
['activation_kv_batch', ['expert']],
4347
['activation_kv_batch_no_exp', []],
4448
['activation_kv_head_dim', ['model']],
45-
['activation_vocab', ['model']],
46-
['activation_embed', ['model']],
49+
['activation_vocab', ['model', 'attn_dp']],
50+
['activation_norm_length', []],
4751
['activation_exp', ['expert']],
4852
['decode_batch', ['expert']],
49-
['mlp', ['model']],
50-
['mlp_no_fsdp', ['model']],
51-
['vocab', ['model']],
53+
['decode_length', []],
54+
['mlp', ['model', 'attn_dp']],
55+
['mlp_no_fsdp', ['model', 'attn_dp']],
56+
['vocab', ['model', 'attn_dp']],
5257
['heads', ['model']],
5358
['q_heads', ['model']],
5459
['kv_heads', ['model']],
5560
['kv_head_dim', []],
5661
['kv', []],
5762
['embed', ['expert']],
63+
['embed_tensor_transpose', ['attn_dp', 'model']],
5864
['embed_no_exp', []],
5965
['q_lora', ['expert']],
6066
['kv_lora', ['expert']],
61-
['norm', ['model']],
67+
['norm', []],
6268
['cache_heads', ['model']],
6369
['exp', ['expert']],
6470
['paged_kv_heads', ['model']],
6571
]
66-
data_sharding: [['data', 'model', 'expert']]
72+
data_sharding: [['data', 'attn_dp', 'model', 'expert']]
6773
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']

src/MaxText/layers/attentions.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
D_KV,
3535
AxisNames,
3636
AxisIdxes,
37-
LENGTH,
38-
LENGTH_NO_EXP,
37+
ATTN_LENGTH,
38+
ATTN_LENGTH_NO_EXP,
3939
DType,
4040
Config,
4141
Array,
@@ -46,7 +46,7 @@
4646
KV_HEAD_DIM,
4747
KV_BATCH,
4848
KV_BATCH_NO_EXP,
49-
EMBED,
49+
ATTN_EMBED,
5050
MODEL_MODE_AUTOREGRESSIVE,
5151
MODEL_MODE_TRAIN,
5252
MODEL_MODE_PREFILL,
@@ -141,18 +141,18 @@ def attention_as_linen(
141141
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
142142
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
143143
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
144-
query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
145-
key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
146-
value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
147-
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
148-
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
149-
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
150-
input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
151-
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
152-
out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
153-
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
154-
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
155-
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
144+
query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
145+
key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
146+
value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
147+
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
148+
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
149+
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
150+
input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
151+
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED),
152+
out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
153+
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV),
154+
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
155+
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
156156
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
157157
decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
158158
prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
@@ -300,18 +300,18 @@ def __init__(
300300
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
301301
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
302302
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
303-
query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
304-
key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
305-
value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
306-
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
307-
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
308-
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
309-
input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
310-
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
311-
out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
312-
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
313-
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
314-
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
303+
query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
304+
key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
305+
value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
306+
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
307+
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
308+
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
309+
input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
310+
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED),
311+
out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
312+
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV),
313+
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
314+
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
315315
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
316316
decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
317317
prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),

src/MaxText/layers/moe.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def __init__(
355355

356356
if self.config.attention == "vllm_rpa":
357357
# vLLM spreads tensor parallelism across the 'model' and 'attn_dp' mesh axes
358-
self._tensor_parallelism_name = "model"
358+
self._tensor_parallelism_name = ("model", "attn_dp")
359359
else:
360360
self._tensor_parallelism_name = "tensor"
361361

@@ -459,6 +459,11 @@ def get_expert_parallelism_size(self):
459459
return self.mesh.shape.get("expert", 1)
460460

461461
def get_tensor_parallelism_size(self):
462+
if isinstance(self._tensor_parallelism_name, tuple):
463+
size = 1
464+
for axis in self._tensor_parallelism_name:
465+
size *= self.mesh.shape.get(axis, 1)
466+
return size
462467
return self.mesh.shape.get(self._tensor_parallelism_name, 1)
463468

464469
def get_tensor_transpose_parallelism_size(self):

src/MaxText/maxtext_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,12 +1204,12 @@ def schedule(step):
12041204
return optax.join_schedules(pieces, boundaries)
12051205

12061206

1207-
def print_state_mesh_shardings_params(state, state_sharding, mesh):
1207+
def print_shardings_params(params, params_sharding, mesh):
12081208
"""Print state shardings."""
1209-
leaves_params, _ = jax.tree_util.tree_flatten_with_path(state.params)
1210-
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(state_sharding.params)
1209+
leaves_params, _ = jax.tree_util.tree_flatten_with_path(params)
1210+
leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding)
12111211
for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding):
1212-
path_str = "/".join(str(p.key) for p in path)
1212+
path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path)
12131213
shape = jax.typeof(leaf_val)
12141214
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
12151215
max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")

src/MaxText/model_creation_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import jax
2424
from jax.sharding import Mesh, AxisType
2525
from MaxText import maxtext_utils
26+
from MaxText import max_utils
2627
from MaxText import pyconfig
2728
from MaxText.layers import quantizations
2829
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
@@ -153,7 +154,10 @@ def create_sharded_state():
153154
with nn.logical_axis_rules(config.logical_axis_rules):
154155
sharded_state = create_sharded_state()
155156
model = nnx.merge(graphdef, sharded_state)
156-
157+
# Print weight sharding info when config.debug_sharding is enabled
158+
if config.debug_sharding:
159+
max_utils.print_non_trivial_mesh_axis(model.mesh)
160+
maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh)
157161
if config.load_parameters_path:
158162
try:
159163
ckptr = ocp.Checkpointer(

src/MaxText/train_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def main(argv: Sequence[str]) -> None:
228228
# print weights sharding info under debug sharding mode
229229
if config.debug_sharding:
230230
max_utils.print_non_trivial_mesh_axis(topology_mesh)
231-
maxtext_utils.print_state_mesh_shardings_params(shaped_train_args[0], state_mesh_shardings, topology_mesh)
231+
maxtext_utils.print_shardings_params(shaped_train_args[0].params, state_mesh_shardings.params, topology_mesh)
232232

233233
# Compile
234234
print("Jitting and compiling train step...", flush=True)

src/MaxText/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def setup_train_loop(config, recorder, devices=None):
218218
# print weights sharding info under debug sharding mode
219219
if config.debug_sharding:
220220
max_utils.print_non_trivial_mesh_axis(model.mesh)
221-
maxtext_utils.print_state_mesh_shardings_params(state, state_mesh_shardings, model.mesh)
221+
maxtext_utils.print_shardings_params(state.params, state_mesh_shardings.params, model.mesh)
222222

223223
if config.use_dpo:
224224
abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)

0 commit comments

Comments
 (0)