Skip to content

Commit e83d458

Browse files
authored
[Speculate Decoding] Fix step_idx semantics in limit_thinking and set_stop_value kernels (#7166)
- speculate_limit_thinking_content_length: update current_base_step to step_idx + 1 (step_idx now records the history count before the current round); remove the incorrect step_idx decrement on accept_num truncation; mark the step_idx parameter as const.
- speculate_set_stop_value_multi_seqs: fix the can_stop gate to use step_idx_now + accept_num >= min_token_limit; fix the skip check and the pre_ids_idx formula (remove the stale -accept_num offset); use a <= loop condition so that accept_idx maps directly to the accepted token that ends the stop sequence; fix the accept_tokens index (remove the -1).
- Update unit tests for the speculate_set_stop_value_multi_seqs kernel.
1 parent 73bd4ab commit e83d458

5 files changed

Lines changed: 229 additions & 106 deletions

File tree

custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
3434
int64_t* next_tokens, // [bs, tokens_per_step]
3535
const int* max_think_lens, // [bs]
3636
int* max_reply_lens, // [bs]
37-
int64_t* step_idx, // [bs]
37+
const int64_t* step_idx, // [bs]
3838
const int64_t* eos_token_ids, // [eos_len]
3939
int* limit_status, // [bs]
4040
int* accept_num, // [bs]
@@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
6868
int new_accept_num = original_accept_num;
6969

7070
// 本 step 的 token offset 对应的绝对 step
71-
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
71+
const int64_t current_base_step = step_idx[bid] + 1;
7272

7373
for (int token_offset = 0; token_offset < original_accept_num;
7474
token_offset++) {
@@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
100100
// inject_token_ids[0])
101101
if (status == 0 &&
102102
(current_step - 1) ==
103-
max_think_len) { // current_step - 1 是因为 speculate_verify 里
104-
// step_idx + 1 了
103+
max_think_len) { // current_step - 1 : 已输出 current_step-1
104+
// 个thinking token
105105
status = (inject_len > 0) ? 1 : done_status;
106106
}
107107
} else if (max_think_len == 0) {
@@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
181181
}
182182
}
183183

184-
// 更新 step_idx / accept_num(被截断的 token 需要回退
185-
// step_idx)
186-
const int discarded_tokens = original_accept_num - new_accept_num;
187-
if (discarded_tokens > 0) {
188-
step_idx[bid] -= discarded_tokens;
189-
}
190-
191184
accept_num[bid] = new_accept_num;
192185
limit_status[bid] = status;
193186
max_reply_lens[bid] = max_reply_len;
@@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
221214
const_cast<int64_t*>(next_tokens.data<int64_t>()),
222215
max_think_lens.data<int>(),
223216
const_cast<int*>(max_reply_lens.data<int>()),
224-
const_cast<int64_t*>(step_idx.data<int64_t>()),
217+
step_idx.data<int64_t>(),
225218
eos_token_ids.data<int64_t>(),
226219
const_cast<int*>(limit_status.data<int>()),
227220
const_cast<int*>(accept_num.data<int>()),

custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
5151
const int64_t step_idx_now = step_idx[bid];
5252
const int64_t min_token_limit = min_tokens[bid];
5353

54-
const bool can_stop = (step_idx_now >= min_token_limit);
54+
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
5555
if (!can_stop) return;
5656
if (!stop_flags[bid]) {
57-
int accept_idx = 0;
57+
/*
58+
accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based)
59+
accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾
60+
(pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。
61+
为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1]
62+
(当前轮最后一个 token),该 token 延迟到下一轮匹配。
63+
循环范围:accept_num > 0 时为 [-1, accept_num-2];
64+
accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。
65+
*/
66+
int accept_idx = -1;
5867
bool is_end = false;
59-
// 遍历起始位置
60-
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
68+
69+
// 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾
70+
// 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens
71+
// 中的匹配。两者共享同一套从后向前匹配逻辑。
72+
int loop_end = (accept_num > 0) ? accept_num - 2 : -1;
73+
for (; accept_idx <= loop_end && !is_end; accept_idx++) {
6174
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
6275
#ifdef DEBUG_SPEC_STOP_SEQS
6376
printf("num %d < stop_seq_len %d\n",
64-
step_idx_now - accept_num + accept_idx + 1,
77+
step_idx_now + accept_idx + 1,
6578
stop_seq_len);
6679
#endif
6780
continue;
6881
}
69-
// 遍历一个 stop_seqs
82+
// 从后向前匹配 stop_seq 的每个 token
7083
for (int i = stop_seq_len - 1; i >= 0; --i) {
7184
int64_t cur_token_idx = -1;
7285

73-
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
74-
if (stop_seq_len - 1 - i < accept_idx) {
86+
int offset = stop_seq_len - 1 - i;
87+
int accept_tokens_idx = accept_idx - offset;
88+
89+
if (accept_tokens_idx >= 0) {
7590
#ifdef DEBUG_SPEC_STOP_SEQS
7691
printf(
7792
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
78-
"accept_token_idx: "
79-
"%d\n",
93+
"accept_token_idx: %d\n",
8094
bid,
8195
tid,
8296
accept_idx,
83-
accept_idx - (stop_seq_len - 1 - i) - 1);
97+
accept_tokens_idx);
8498
#endif
85-
cur_token_idx =
86-
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
99+
cur_token_idx = accept_tokens_now[accept_tokens_idx];
87100
} else {
101+
int pre_ids_idx = step_idx_now + accept_tokens_idx;
88102
#ifdef DEBUG_SPEC_STOP_SEQS
89103
printf(
90104
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
91-
"accept_idx:%d. "
92-
"pre_id_idx: %ld\n",
105+
"accept_idx:%d. pre_id_idx: %d\n",
93106
bid,
94107
tid,
95108
step_idx_now,
96109
accept_idx,
97-
step_idx_now - accept_num + accept_idx -
98-
(stop_seq_len - 1 - i));
110+
pre_ids_idx);
99111
#endif
100-
int pre_ids_idx =
101-
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
102-
// EC3
103-
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
104-
// 导致异常结束
105-
if (pre_ids_idx <= 0) {
106-
break;
107-
}
112+
if (pre_ids_idx < 0) break;
108113
cur_token_idx = pre_ids_now[pre_ids_idx];
109114
}
110115
#ifdef DEBUG_SPEC_STOP_SEQS
@@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
126131
}
127132
if (is_end) {
128133
#ifdef DEBUG_SPEC_STOP_SEQS
129-
printf("bid:%d end with accept_idx %d", bid, accept_idx);
134+
printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
130135
#endif
131-
132-
accept_nums[bid] = accept_idx;
133-
accept_tokens_now[accept_idx - 1] = end_ids[0];
134-
// stop_flags[bid] = true;
136+
// accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置
137+
accept_nums[bid] = accept_idx + 1;
138+
accept_tokens_now[accept_idx] = end_ids[0];
135139
}
136140
}
137141
}

custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
121121
int64_t *token_ids_all_now =
122122
&token_ids_all[batch_id * max_model_len + prompt_len];
123123
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
124-
int64_t base = cur_step_idx - output_len + 1;
124+
int64_t base = cur_step_idx - output_len;
125125
for (int i = 0; i < output_len; i++) {
126126
token_ids_all_now[base + i] = output_ids[i];
127127
}

0 commit comments

Comments
 (0)