Skip to content

Commit 29495b2

Browse files
Jiajun-Ji and cmcamdy
authored
[XPU] Unify Spec and non-spec branch.(#6947) (#7180)
* [XPU] cherry-pick PR-6947 * [XPU] use unified_update_model_status. * refactor xpu_model_runner. * refactor sampler. * fix codestyle. * Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path. * fix codestyle. * replace output_padding_offset with is_speculative flag in gather_next_token. * rename hiddden_states. * unify cu_seqlens_q_output and batch_id_per_token_output init. --------- Co-authored-by: cmcamdy <1027740945@qq.com>
1 parent 17002ed commit 29495b2

File tree

9 files changed

+226
-149
lines changed

9 files changed

+226
-149
lines changed

custom_ops/xpu_ops/src/ops/gather_next_token.cc

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
3232
const paddle::Tensor& encoder_batch_map_cpu,
3333
const paddle::Tensor& decoder_batch_map_cpu,
3434
const paddle::Tensor& len_info_cpu,
35-
const paddle::optional<paddle::Tensor>& output_padding_offset,
35+
bool is_speculative,
3636
int max_bsz) {
3737
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
3838
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
7373
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
7474

7575
paddle::Tensor out;
76-
if (output_padding_offset) {
76+
if (is_speculative) {
7777
int need_delete_token_num = 0;
7878
if (enc_batch > 0) {
7979
need_delete_token_num =
@@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
8888
return {out};
8989
}
9090

91-
if (output_padding_offset) {
91+
if (is_speculative) {
9292
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
9393
ctx,
9494
reinterpret_cast<const XPUType*>(x.data<data_t>()),
@@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
124124
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
125125
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
126126
const std::vector<int64_t>& len_info_cpu_shape,
127-
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
128-
// if (output_padding_offset_shape) {
129-
// PD_THROW("speculative decoding is not supported in XPU.");
130-
// }
131-
// int64_t bsz = cum_offsets_shape[0];
127+
bool is_speculative) {
132128
int64_t bsz = 0;
133129
int64_t dim_embed = x_shape[1];
134-
if (output_padding_offset_shape) {
130+
if (is_speculative) {
135131
return {{-1, dim_embed}};
136132
} else {
137133
return {{bsz, dim_embed}};
@@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
148144
const paddle::DataType& decoder_seq_lod_cpu_dtype,
149145
const paddle::DataType& encoder_batch_map_cpu_dtype,
150146
const paddle::DataType& decoder_batch_map_cpu_dtype,
151-
const paddle::DataType& len_info_cpu_dtype,
152-
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
147+
const paddle::DataType& len_info_cpu_dtype) {
153148
return {x_dtype};
154149
}
155150

@@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
163158
"decoder_seq_lod_cpu",
164159
"encoder_batch_map_cpu",
165160
"decoder_batch_map_cpu",
166-
"len_info_cpu",
167-
paddle::Optional("output_padding_offset")})
161+
"len_info_cpu"})
168162
.Outputs({"out"})
169-
.Attrs({"max_bsz: int"})
163+
.Attrs({"is_speculative: bool", "max_bsz: int"})
170164
.SetKernelFn(PD_KERNEL(GatherNextToken))
171165
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
172166
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));

custom_ops/xpu_ops/src/ops/pybind/pybind.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ std::vector<paddle::Tensor> GatherNextToken(
465465
const paddle::Tensor& encoder_batch_map_cpu,
466466
const paddle::Tensor& decoder_batch_map_cpu,
467467
const paddle::Tensor& len_info_cpu,
468-
const paddle::optional<paddle::Tensor>& output_padding_offset,
468+
bool is_speculative,
469469
int max_bsz);
470470

471471
std::vector<paddle::Tensor> GetImgBoundaries(
@@ -1035,7 +1035,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
10351035
py::arg("encoder_batch_map_cpu"),
10361036
py::arg("decoder_batch_map_cpu"),
10371037
py::arg("len_info_cpu"),
1038-
py::arg("output_padding_offset"),
1038+
py::arg("is_speculative"),
10391039
py::arg("max_bsz"),
10401040
"Gather next token for XPU");
10411041

@@ -1164,6 +1164,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
11641164
py::arg("max_draft_tokens"),
11651165
"Unified update model status");
11661166

1167+
m.def("verify_draft_tokens",
1168+
&VerifyDraftTokens,
1169+
py::arg("step_output_ids"),
1170+
py::arg("step_output_len"),
1171+
py::arg("step_input_ids"),
1172+
py::arg("target_tokens"),
1173+
py::arg("candidate_ids"),
1174+
py::arg("candidate_scores"),
1175+
py::arg("candidate_lens"),
1176+
py::arg("topp"),
1177+
py::arg("stop_flags"),
1178+
py::arg("seq_lens_encoder"),
1179+
py::arg("seq_lens_this_time"),
1180+
py::arg("end_tokens"),
1181+
py::arg("is_block_step"),
1182+
py::arg("cu_seqlens_q_output"),
1183+
py::arg("reasoning_status"),
1184+
py::arg("max_dec_len"),
1185+
py::arg("step_idx"),
1186+
py::arg("max_seq_len"),
1187+
py::arg("verify_window"),
1188+
py::arg("verify_strategy"),
1189+
py::arg("reject_all"),
1190+
py::arg("accept_all"),
1191+
"Perform speculative verification for decoding v2");
1192+
11671193
m.def("mtp_step_paddle",
11681194
&MTPStepPaddle,
11691195
py::arg("base_model_stop_flags"),

custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,6 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
766766
const int eos_token_id_len,
767767
const int inject_len,
768768
const bool splitwise_role_is_decode);
769-
770769
DLL_EXPORT int verify_draft_tokens(
771770
api::Context* ctx,
772771
// Core I/O

custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525

2626

27-
def _run_test_base(seq_lens_this_time_data, output_padding_offset):
27+
def _run_test_base(seq_lens_this_time_data, is_speculative):
2828
"""
2929
通用的基础测试执行函数,包含了两个场景共有的逻辑。
3030
"""
@@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
120120
encoder_batch_map_cpu,
121121
decoder_batch_map_cpu,
122122
len_info_cpu,
123-
output_padding_offset,
123+
is_speculative,
124124
-1,
125125
)
126126

@@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
136136
encoder_batch_map_cpu,
137137
decoder_batch_map_cpu,
138138
len_info_cpu,
139-
output_padding_offset,
139+
is_speculative,
140140
-1,
141141
)
142142

143143
gather_out_np = gather_out.astype("float32").cpu().numpy()
144144
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()
145145

146-
if output_padding_offset is not None:
146+
if is_speculative:
147147
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
148148
else:
149149
for i in range(gather_out_cpu.shape[0]):
@@ -160,19 +160,14 @@ def test_mix_with_mtp(self):
160160
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
161161
print("\nRunning test: test_mix_with_mtp")
162162
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
163-
bsz = len(seq_lens_this_time_data)
164-
output_padding_offset = paddle.zeros(bsz, dtype="int32")
165-
166-
_run_test_base(seq_lens_this_time_data, output_padding_offset)
163+
_run_test_base(seq_lens_this_time_data, True)
167164
print("Test passed for scenario: With MTP")
168165

169166
def test_mix_without_mtp(self):
170167
"""测试非 MTP (Single-Token Prediction) 场景下的功能"""
171168
print("\nRunning test: test_mix_without_mtp")
172169
seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
173-
output_padding_offset = None # 非 MTP 场景下,此参数为 None
174-
175-
_run_test_base(seq_lens_this_time_data, output_padding_offset)
170+
_run_test_base(seq_lens_this_time_data, False)
176171
print("Test passed for scenario: Without MTP")
177172

178173

fastdeploy/model_executor/forward_meta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ class XPUForwardMeta(ForwardMeta):
275275
hidden_states: Optional[paddle.Tensor] = None
276276

277277
is_draft: bool = False
278+
is_speculative: bool = False
278279
# max bs
279280
max_num_seqs: int = 0
280281

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 128 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,17 +1045,129 @@ def forward_cuda(
10451045
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
10461046
return sampler_output
10471047

1048-
def forward_xpu(
1048+
def _normal_sample_xpu(
1049+
self,
1050+
logits: paddle.Tensor,
1051+
probs: paddle.Tensor,
1052+
sampling_metadata: SamplingMetadata,
1053+
share_inputs: List[paddle.Tensor],
1054+
) -> SamplerOutput:
1055+
"""Normal sampling for NAIVE mode on XPU."""
1056+
top_p, top_k, topp_seed = padding_sampling_params(
1057+
sampling_metadata.top_p,
1058+
sampling_metadata.top_k,
1059+
sampling_metadata.seed,
1060+
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
1061+
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
1062+
)
1063+
_, next_tokens = top_k_top_p_sampling(
1064+
probs,
1065+
top_p=top_p,
1066+
top_k=top_k,
1067+
top_k_list=sampling_metadata.top_k_list,
1068+
topp_seed=topp_seed,
1069+
)
1070+
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
1071+
running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32")
1072+
share_inputs["accept_tokens"][:real_bsz, 0] = next_tokens.squeeze(-1)
1073+
share_inputs["accept_num"][:real_bsz] = running_mask
1074+
return SamplerOutput(
1075+
sampled_token_ids=share_inputs["accept_tokens"],
1076+
logprobs_tensors=None,
1077+
token_num_per_batch=share_inputs["accept_num"],
1078+
logits=logits,
1079+
)
1080+
1081+
def _verify_and_sample_xpu(
10491082
self,
10501083
logits: paddle.Tensor,
1084+
probs: paddle.Tensor,
10511085
sampling_metadata: SamplingMetadata,
10521086
max_model_len: int,
10531087
share_inputs: List[paddle.Tensor],
10541088
accept_all_drafts: bool = False,
10551089
reject_all_drafts: bool = False,
1056-
) -> paddle.Tensor:
1057-
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
1090+
) -> SamplerOutput:
1091+
"""Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens."""
1092+
from fastdeploy.model_executor.ops.xpu import (
1093+
top_p_candidates,
1094+
verify_draft_tokens,
1095+
)
1096+
1097+
target_tokens = None
1098+
candidate_ids, candidate_scores, candidate_lens = None, None, None
10581099

1100+
if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
1101+
top_p, top_k, topp_seed = padding_sampling_params(
1102+
sampling_metadata.top_p,
1103+
sampling_metadata.top_k,
1104+
sampling_metadata.seed,
1105+
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
1106+
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
1107+
)
1108+
_, target_tokens = top_k_top_p_sampling(
1109+
probs,
1110+
top_p=top_p,
1111+
top_k=top_k,
1112+
top_k_list=sampling_metadata.top_k_list,
1113+
topp_seed=topp_seed,
1114+
)
1115+
elif self.verify_strategy == VerifyStrategy.GREEDY:
1116+
target_tokens = paddle.argmax(probs, axis=-1)
1117+
elif self.verify_strategy == VerifyStrategy.TOPP:
1118+
candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
1119+
probs,
1120+
sampling_metadata.top_p,
1121+
share_inputs["batch_id_per_token_output"],
1122+
self.speculative_max_candidate_len,
1123+
max_model_len,
1124+
)
1125+
else:
1126+
raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")
1127+
1128+
final_accept_all = self.config_accept_all or accept_all_drafts
1129+
final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode
1130+
1131+
verify_draft_tokens(
1132+
share_inputs["accept_tokens"],
1133+
share_inputs["accept_num"],
1134+
share_inputs["draft_tokens"],
1135+
target_tokens,
1136+
candidate_ids,
1137+
candidate_scores,
1138+
candidate_lens,
1139+
sampling_metadata.top_p,
1140+
share_inputs["stop_flags"],
1141+
share_inputs["seq_lens_encoder"],
1142+
share_inputs["seq_lens_this_time"],
1143+
sampling_metadata.eos_token_ids,
1144+
share_inputs["is_block_step"],
1145+
share_inputs["cu_seqlens_q_output"],
1146+
share_inputs["reasoning_status"],
1147+
share_inputs["max_dec_len"],
1148+
share_inputs["step_idx"],
1149+
max_model_len,
1150+
self.speculative_verify_window,
1151+
self.verify_strategy.value,
1152+
final_reject_all,
1153+
final_accept_all,
1154+
)
1155+
return SamplerOutput(
1156+
sampled_token_ids=share_inputs["accept_tokens"],
1157+
logprobs_tensors=None,
1158+
token_num_per_batch=share_inputs["accept_num"],
1159+
logits=logits,
1160+
)
1161+
1162+
def forward_xpu(
1163+
self,
1164+
logits: paddle.Tensor,
1165+
sampling_metadata: SamplingMetadata,
1166+
max_model_len: int,
1167+
share_inputs: List[paddle.Tensor],
1168+
accept_all_drafts: bool = False,
1169+
reject_all_drafts: bool = False,
1170+
) -> SamplerOutput:
10591171
logits = apply_speculative_penalty_multi_scores(
10601172
sampling_metadata.token_ids_all,
10611173
sampling_metadata.prompt_lens,
@@ -1078,61 +1190,19 @@ def forward_xpu(
10781190

10791191
probs = F.softmax(logits)
10801192

1081-
top_p, top_k, topp_seed = padding_sampling_params(
1082-
sampling_metadata.top_p,
1083-
sampling_metadata.top_k,
1084-
sampling_metadata.seed,
1085-
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
1086-
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
1087-
)
1088-
_, sampled_token_ids = top_k_top_p_sampling(
1089-
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
1090-
)
1091-
1092-
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
1093-
probs,
1094-
sampling_metadata.top_p,
1095-
share_inputs["batch_id_per_token_output"],
1096-
self.speculative_max_candidate_len,
1097-
max_model_len,
1098-
)
1099-
1100-
speculate_verify(
1101-
sampled_token_ids,
1102-
share_inputs["accept_tokens"],
1103-
share_inputs["accept_num"],
1104-
share_inputs["step_idx"],
1105-
share_inputs["stop_flags"],
1106-
share_inputs["seq_lens_encoder"],
1107-
share_inputs["seq_lens_decoder"],
1108-
share_inputs[
1109-
"draft_tokens"
1110-
], # Both input and output, need to write the last 1 token accepted to position 0.
1111-
share_inputs["seq_lens_this_time"],
1112-
verify_tokens,
1113-
verify_scores,
1114-
share_inputs["max_dec_len"],
1115-
sampling_metadata.eos_token_ids,
1116-
share_inputs["is_block_step"],
1117-
share_inputs["cu_seqlens_q_output"],
1118-
actual_candidate_len,
1119-
share_inputs["actual_draft_token_num"],
1120-
sampling_metadata.top_p,
1121-
max_model_len,
1122-
self.speculative_verify_window,
1123-
True, # enable_topp
1124-
(self.speculative_benchmark_mode or reject_all_drafts),
1125-
accept_all_drafts,
1126-
)
1127-
# TODO(chenhuan09): support return logprobs
1128-
token_ids = share_inputs["accept_tokens"]
1129-
sampler_output = SamplerOutput(
1130-
sampled_token_ids=token_ids,
1131-
logprobs_tensors=None,
1132-
token_num_per_batch=share_inputs["accept_num"],
1133-
cu_batch_token_offset=None,
1134-
)
1135-
return sampler_output
1193+
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
1194+
if is_naive:
1195+
return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
1196+
else:
1197+
return self._verify_and_sample_xpu(
1198+
logits,
1199+
probs,
1200+
sampling_metadata,
1201+
max_model_len,
1202+
share_inputs,
1203+
accept_all_drafts,
1204+
reject_all_drafts,
1205+
)
11361206

11371207

11381208
class MTPSampler(nn.Layer):

0 commit comments

Comments
 (0)