Skip to content

Commit e0a1653

Browse files
lonelygsh and guanshihui] authored
[Speculate Decoding] Fix bug of reasoning_phase_token_constraint kernel (#7349)
Co-authored-by: guanshihui] <guanshihui@baidu.com>
1 parent 7b0bace commit e0a1653

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

custom_ops/gpu_ops/reasoning_phase_token_constraint.cu

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
3838
// - In MTP mode, accept_num must be 1 in verify kernel
3939
//
4040
// Transition condition (x = 1 -> x = 2):
41-
// - step_idx >= 3
41+
// - step_idx >= 4
4242
// - pre_ids[-4:] exactly match:
4343
// "\n</think>\n\n"
4444
//
@@ -83,10 +83,10 @@ __global__ void update_reasoning_status_kernel(
8383
int64_t cur_step = step_idx[tid];
8484
const int64_t* pre_ids_now =
8585
token_ids_all + tid * max_seq_len + prompt_lens[tid];
86-
int64_t t0 = (cur_step >= 0) ? pre_ids_now[cur_step] : -1;
87-
int64_t t1 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
88-
int64_t t2 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
89-
int64_t t3 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
86+
int64_t t0 = (cur_step >= 1) ? pre_ids_now[cur_step - 1] : -1;
87+
int64_t t1 = (cur_step >= 2) ? pre_ids_now[cur_step - 2] : -1;
88+
int64_t t2 = (cur_step >= 3) ? pre_ids_now[cur_step - 3] : -1;
89+
int64_t t3 = (cur_step >= 4) ? pre_ids_now[cur_step - 4] : -1;
9090

9191
int32_t new_status = status;
9292

@@ -104,7 +104,7 @@ __global__ void update_reasoning_status_kernel(
104104
// x = 1 -> x = 2 (include think_end_id)
105105
// or x = 1 -> x = 3 (not include think_end_id)
106106
// Here must be serial judge
107-
if (new_status == 1 && cur_step >= 3) {
107+
if (new_status == 1 && cur_step >= 4) {
108108
if (t3 == line_break_id && t2 == think_end_id && t1 == line_break_id &&
109109
t0 == line_break_id) {
110110
new_status = 2;

tests/operators/test_reasoning_phase_token_constraint.py

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -37,22 +37,24 @@ def setUp(self):
3737
# token_ids_all
3838
#
3939
# batch 0:
40-
# ... \n <think_end> \n \n → status 1 -> 2
40+
# step_idx=4, pre_ids_now[0..3]
41+
# pattern: \n <think_end> \n \n → status 1 -> 2
42+
# t3=pre_ids_now[0]=\n, t2=pre_ids_now[1]=<think_end>, t1=pre_ids_now[2]=\n, t0=pre_ids_now[3]=\n
4143
#
4244
# batch 1:
43-
# contains think_end, but pattern not complete → status 0 -> 1
45+
# contains think_end at pre_ids_now[2], but pattern not complete → status 0 -> 1
4446
# ------------------------
4547
token_ids_all = np.zeros((self.bs, self.max_seq_len), dtype=np.int64)
4648
self.prompt_lens = paddle.zeros([self.bs, 1], dtype="int64")
4749

48-
# batch 0
49-
token_ids_all[0, 1] = self.line_break_id
50-
token_ids_all[0, 2] = self.think_end_id
50+
# batch 0: pattern \n <think_end> \n \n at pre_ids_now[0..3]
51+
token_ids_all[0, 0] = self.line_break_id
52+
token_ids_all[0, 1] = self.think_end_id
53+
token_ids_all[0, 2] = self.line_break_id
5154
token_ids_all[0, 3] = self.line_break_id
52-
token_ids_all[0, 4] = self.line_break_id
5355

54-
# batch 1
55-
token_ids_all[1, 3] = self.think_end_id
56+
# batch 1: think_end at pre_ids_now[2]
57+
token_ids_all[1, 2] = self.think_end_id
5658

5759
self.token_ids_all = paddle.to_tensor(token_ids_all, dtype="int64")
5860
self.prompt_lens = paddle.zeros([self.bs, 1], dtype="int64")
@@ -167,11 +169,13 @@ def test_status_0_to_1_only(self):
167169

168170
# ------------------------
169171
# setup: only think_end appears
172+
# step_idx=4, pre_ids_now[0..3]
173+
# think_end at pre_ids_now[2] (cur_step - 2 = 4 - 2 = 2)
170174
# ------------------------
171175
token_ids_all = np.zeros((self.bs, self.max_seq_len), dtype=np.int64)
172176

173-
# batch 0: think_end at cur_step - 1
174-
token_ids_all[0, 3] = self.think_end_id
177+
# batch 0: think_end at pre_ids_now[2]
178+
token_ids_all[0, 2] = self.think_end_id
175179

176180
# batch 1: no think_end
177181
token_ids_all[1, :] = 0
@@ -424,13 +428,15 @@ def test_perf_bsz128_vocab100k_status2(self):
424428

425429
# ------------------------
426430
# token_ids_all: force 1 -> 2 pattern
431+
# step_idx=4, pre_ids_now[0..3]
432+
# pattern: t3=pre_ids_now[0]=\n, t2=pre_ids_now[1]=<think_end>, t1=pre_ids_now[2]=\n, t0=pre_ids_now[3]=\n
427433
# ------------------------
428434
token_ids_all = np.zeros((bs, max_seq_len), dtype=np.int64)
429435
for i in range(bs):
430-
token_ids_all[i, 1] = line_break_id
431-
token_ids_all[i, 2] = think_end_id
436+
token_ids_all[i, 0] = line_break_id
437+
token_ids_all[i, 1] = think_end_id
438+
token_ids_all[i, 2] = line_break_id
432439
token_ids_all[i, 3] = line_break_id
433-
token_ids_all[i, 4] = line_break_id
434440

435441
token_ids_all = paddle.to_tensor(token_ids_all, dtype="int64")
436442

0 commit comments

Comments
 (0)