Skip to content

Commit 31e2a8b

Browse files
authored
[Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323)
* support mtp overlap in pd-split mode with insert_task overlap
1 parent 5ddd1af commit 31e2a8b

6 files changed

Lines changed: 351 additions & 122 deletions

File tree

fastdeploy/eplb/async_expert_loader.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,23 @@
2424
import paddle
2525

2626
try:
27-
from cuda import cudart
28-
except ImportError:
27+
import cuda as _cuda_pkg
28+
29+
_cuda_ver = getattr(_cuda_pkg, "__version__", None)
30+
if _cuda_ver is None:
31+
# cuda-python >= 13.x has no top-level __version__; detect the version via the cuda-bindings subpackage instead
32+
import importlib.metadata as _meta
33+
34+
_cuda_ver = _meta.version("cuda-bindings")
35+
_cuda_major = int(_cuda_ver.split(".")[0])
36+
if _cuda_major >= 13:
37+
from cuda.bindings import runtime as cudart
38+
else:
39+
from cuda import cudart
40+
except Exception as _e:
41+
import warnings
42+
43+
warnings.warn(f"cuda-python import failed, async_expert_loader will be unavailable: {_e}")
2944
cudart = None
3045

3146
from fastdeploy.config import EPLBConfig
@@ -98,6 +113,7 @@ def create_mmap(model_name: List, ep_rank: int, ep_size: int, shm_uuid: str, epl
98113
raise ImportError(
99114
"cuda-python not installed. Install the version matching your CUDA toolkit:\n"
100115
" CUDA 12.x → pip install cuda-python==12.*\n"
116+
" CUDA 13.x → pip install cuda-python cuda-bindings\n"
101117
)
102118

103119
# Register memory with CUDA

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -116,36 +116,33 @@
116116

117117
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
118118

119-
if current_platform.is_cuda():
120-
121-
def async_set_value(tgt, src):
122-
if isinstance(src, (int, float, bool)):
123-
src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype)
124-
elif isinstance(src, (list, np.array)):
125-
dtype_str = str(tgt.dtype).split(".")[1]
126-
if isinstance(src, list):
127-
src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32")
119+
120+
def async_set_value(tgt, src):
121+
if isinstance(src, (int, float, bool)):
122+
src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype)
123+
elif isinstance(src, (list, np.ndarray)):
124+
dtype_str = str(tgt.dtype).split(".")[1]
125+
if isinstance(src, list):
126+
src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32")
127+
if current_platform.is_cuda():
128128
if str(src.dtype) != dtype_str:
129129
srt_tensor = paddle.empty(tgt.shape, dtype=str(src.dtype))
130130
src = custom_numpy_to_tensor(src, srt_tensor)
131131
else:
132132
return custom_numpy_to_tensor(src, tgt)
133-
elif isinstance(src, paddle.Tensor):
134-
pass
135133
else:
136-
raise ValueError("async_set_value unsupported src type: {}".format(type(src)))
137-
if src.shape != tgt.shape:
138-
src = src.reshape(tgt.shape)
139-
if src.dtype != tgt.dtype:
140-
src = src.cast(tgt.dtype)
141-
if src.place != tgt.place:
142-
src = src.to(tgt.place)
143-
tgt.copy_(src, blocking=False)
144-
145-
else:
146-
147-
def async_set_value(*args, **kwargs):
148-
raise RuntimeError("async_set_value is only available on CUDA")
134+
src = paddle.to_tensor(src, dtype=tgt.dtype)
135+
elif isinstance(src, paddle.Tensor):
136+
pass
137+
else:
138+
raise ValueError("async_set_value unsupported src type: {}".format(type(src)))
139+
if src.shape != tgt.shape:
140+
src = src.reshape(tgt.shape)
141+
if src.dtype != tgt.dtype:
142+
src = src.cast(tgt.dtype)
143+
if src.place != tgt.place:
144+
src = src.to(tgt.place)
145+
tgt.copy_(src, blocking=False)
149146

150147

151148
def pre_process(

fastdeploy/model_executor/xpu_pre_and_post_process.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@
5555
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
5656

5757

58+
def async_set_value(tgt, src):
59+
if isinstance(src, (int, float, bool)):
60+
src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype)
61+
elif isinstance(src, (list, np.ndarray)):
62+
dtype_str = str(tgt.dtype).split(".")[1]
63+
np_dtype = dtype_str if dtype_str != "bfloat16" else "float32"
64+
if isinstance(src, list):
65+
src = np.array(src, dtype=np_dtype)
66+
# TODO: support async_numpy_to_tensor
67+
src = paddle.to_tensor(src, dtype=tgt.dtype)
68+
elif isinstance(src, paddle.Tensor):
69+
pass
70+
else:
71+
raise ValueError("async_set_value unsupported src type: {}".format(type(src)))
72+
if src.shape != tgt.shape:
73+
src = src.reshape(tgt.shape)
74+
if src.dtype != tgt.dtype:
75+
src = src.cast(tgt.dtype)
76+
if src.place != tgt.place:
77+
src = src.to(tgt.place)
78+
tgt.copy_(src, blocking=False)
79+
80+
5881
def _build_stream_transfer_data(
5982
output_tokens: paddle.Tensor,
6083
pooler_outputs: List = None,

fastdeploy/spec_decode/mtp.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@
4949
share_external_data,
5050
update_attn_mask_offsets,
5151
)
52+
53+
# temporary solution
5254
from fastdeploy.model_executor.xpu_pre_and_post_process import (
55+
async_set_value,
5356
xpu_pre_process,
5457
xpu_process_output,
5558
)
@@ -483,28 +486,32 @@ def insert_tasks_v1(
483486
input_ids = request.prompt_token_ids + request.output_token_ids
484487

485488
self.model_inputs["input_ids_len"][idx] = length - 1
486-
self.model_inputs["pre_ids"][idx : idx + 1] = -1
489+
async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1)
487490
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
488491
idx : idx + 1, 1:length
489492
]
490-
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
491-
"input_ids"
492-
][idx : idx + 1, 1:length].cpu()
493+
# TODO: use token_all_ids replace with input_ids_cpu
494+
if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs:
495+
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
496+
"input_ids"
497+
][idx : idx + 1, 1:length].cpu()
493498
encoder_block_num = len(request.block_tables)
494-
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
495-
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
496-
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
497-
request.block_tables, dtype="int32"
499+
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
500+
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
501+
async_set_value(
502+
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
498503
)
499-
self.model_inputs["stop_flags"][idx : idx + 1] = False
500-
self.model_inputs["batch_drop"][idx : idx + 1] = False
501504

502-
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
505+
async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False)
506+
async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False)
507+
508+
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length)
503509
self.exist_prefill_flag = True
504-
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
505-
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
506-
self.model_inputs["step_idx"][idx : idx + 1] = (
507-
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
510+
async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index)
511+
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length)
512+
async_set_value(
513+
self.model_inputs["step_idx"][idx : idx + 1],
514+
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0,
508515
)
509516
if self.use_attn_mask_offset:
510517
inputs = request.multimodal_inputs
@@ -522,18 +529,19 @@ def insert_tasks_v1(
522529
if (
523530
self.fd_config.scheduler_config.splitwise_role == "decode"
524531
): # In PD, we continue to decode after P generates first token
525-
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
532+
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
526533
self.exist_prefill_flag = False
527-
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
528-
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
534+
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1)
529535
# NOTE(liuzichang):
530536
# extra 1 : P-D split need rollback one step
531-
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
537+
538+
async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0)
539+
async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1)
532540
# has_prefill_task = True
533541
elif request.task_type.value == RequestType.DECODE.value: # decode task
534542
encoder_block_num = len(request.block_tables)
535-
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
536-
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
543+
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
544+
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
537545
if current_platform.is_cuda():
538546
async_set_value(
539547
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
@@ -542,16 +550,13 @@ def insert_tasks_v1(
542550
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
543551
request.block_tables, dtype="int32"
544552
)
545-
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
546-
# has_decode_task = True
547-
# continue
548553
else:
549-
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
550-
self.model_inputs["stop_flags"][idx : idx + 1] = True
551-
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0
552-
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
553-
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
554-
self.model_inputs["is_block_step"][idx : idx + 1] = False
554+
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
555+
async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True)
556+
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0)
557+
async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0)
558+
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
559+
async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False)
555560
continue
556561

557562
# TODO(liuzichang): Solve splitewise-p bug to restore
@@ -1233,6 +1238,7 @@ def _update_status(self):
12331238
)
12341239

12351240
def _extend_draft_token_with_ngram_match(self):
1241+
# TODO: replace with gpu tensor
12361242
hybrid_mtp_ngram(
12371243
self.model_inputs["input_ids_cpu"].cuda(),
12381244
self.model_inputs["input_ids_len"].cuda(),

0 commit comments

Comments
 (0)