@@ -37,22 +37,24 @@ def setUp(self):
3737 # token_ids_all
3838 #
3939 # batch 0:
40- # ... \n <think_end> \n \n → status 1 -> 2
40+ # step_idx=4, pre_ids_now[0..3]
41+ # pattern: \n <think_end> \n \n → status 1 -> 2
42+ # t3=pre_ids_now[0]=\n, t2=pre_ids_now[1]=<think_end>, t1=pre_ids_now[2]=\n, t0=pre_ids_now[3]=\n
4143 #
4244 # batch 1:
43- # contains think_end, but pattern not complete → status 0 -> 1
45+ # contains think_end at pre_ids_now[2] , but pattern not complete → status 0 -> 1
4446 # ------------------------
4547 token_ids_all = np .zeros ((self .bs , self .max_seq_len ), dtype = np .int64 )
4648 self .prompt_lens = paddle .zeros ([self .bs , 1 ], dtype = "int64" )
4749
48- # batch 0
49- token_ids_all [0 , 1 ] = self .line_break_id
50- token_ids_all [0 , 2 ] = self .think_end_id
50+ # batch 0: pattern \n <think_end> \n \n at pre_ids_now[0..3]
51+ token_ids_all [0 , 0 ] = self .line_break_id
52+ token_ids_all [0 , 1 ] = self .think_end_id
53+ token_ids_all [0 , 2 ] = self .line_break_id
5154 token_ids_all [0 , 3 ] = self .line_break_id
52- token_ids_all [0 , 4 ] = self .line_break_id
5355
54- # batch 1
55- token_ids_all [1 , 3 ] = self .think_end_id
56+ # batch 1: think_end at pre_ids_now[2]
57+ token_ids_all [1 , 2 ] = self .think_end_id
5658
5759 self .token_ids_all = paddle .to_tensor (token_ids_all , dtype = "int64" )
5860 self .prompt_lens = paddle .zeros ([self .bs , 1 ], dtype = "int64" )
@@ -167,11 +169,13 @@ def test_status_0_to_1_only(self):
167169
168170 # ------------------------
169171 # setup: only think_end appears
172+ # step_idx=4, pre_ids_now[0..3]
173+ # think_end at pre_ids_now[2] (cur_step - 2 = 4 - 2 = 2)
170174 # ------------------------
171175 token_ids_all = np .zeros ((self .bs , self .max_seq_len ), dtype = np .int64 )
172176
173- # batch 0: think_end at cur_step - 1
174- token_ids_all [0 , 3 ] = self .think_end_id
177+ # batch 0: think_end at pre_ids_now[2]
178+ token_ids_all [0 , 2 ] = self .think_end_id
175179
176180 # batch 1: no think_end
177181 token_ids_all [1 , :] = 0
@@ -424,13 +428,15 @@ def test_perf_bsz128_vocab100k_status2(self):
424428
425429 # ------------------------
426430 # token_ids_all: force 1 -> 2 pattern
431+ # step_idx=4, pre_ids_now[0..3]
432+ # pattern: t3=pre_ids_now[0]=\n, t2=pre_ids_now[1]=<think_end>, t1=pre_ids_now[2]=\n, t0=pre_ids_now[3]=\n
427433 # ------------------------
428434 token_ids_all = np .zeros ((bs , max_seq_len ), dtype = np .int64 )
429435 for i in range (bs ):
430- token_ids_all [i , 1 ] = line_break_id
431- token_ids_all [i , 2 ] = think_end_id
436+ token_ids_all [i , 0 ] = line_break_id
437+ token_ids_all [i , 1 ] = think_end_id
438+ token_ids_all [i , 2 ] = line_break_id
432439 token_ids_all [i , 3 ] = line_break_id
433- token_ids_all [i , 4 ] = line_break_id
434440
435441 token_ids_all = paddle .to_tensor (token_ids_all , dtype = "int64" )
436442
0 commit comments