4949 share_external_data ,
5050 update_attn_mask_offsets ,
5151 )
52+
53+ # temporary solution
5254 from fastdeploy .model_executor .xpu_pre_and_post_process import (
55+ async_set_value ,
5356 xpu_pre_process ,
5457 xpu_process_output ,
5558 )
@@ -483,28 +486,32 @@ def insert_tasks_v1(
483486 input_ids = request .prompt_token_ids + request .output_token_ids
484487
485488 self .model_inputs ["input_ids_len" ][idx ] = length - 1
486- self .model_inputs ["pre_ids" ][idx : idx + 1 ] = - 1
489+ async_set_value ( self .model_inputs ["pre_ids" ][idx : idx + 1 ], - 1 )
487490 self .model_inputs ["input_ids" ][idx : idx + 1 , : length - 1 ] = self .target_model_inputs ["input_ids" ][
488491 idx : idx + 1 , 1 :length
489492 ]
490- self .model_inputs ["input_ids_cpu" ][idx : idx + 1 , : length - 1 ] = self .target_model_inputs [
491- "input_ids"
492- ][idx : idx + 1 , 1 :length ].cpu ()
493+ # TODO: use token_all_ids replace with input_ids_cpu
494+ if getattr (self , "hybrid_mode" , False ) and "input_ids_cpu" in self .model_inputs :
495+ self .model_inputs ["input_ids_cpu" ][idx : idx + 1 , : length - 1 ] = self .target_model_inputs [
496+ "input_ids"
497+ ][idx : idx + 1 , 1 :length ].cpu ()
493498 encoder_block_num = len (request .block_tables )
494- self .model_inputs ["encoder_block_lens" ][idx : idx + 1 ] = encoder_block_num
495- self .model_inputs ["block_tables" ][idx : idx + 1 , :] = - 1
496- self . model_inputs [ "block_tables" ][ idx : idx + 1 , : encoder_block_num ] = np . array (
497- request . block_tables , dtype = "int32"
499+ async_set_value ( self .model_inputs ["encoder_block_lens" ][idx : idx + 1 ], encoder_block_num )
500+ async_set_value ( self .model_inputs ["block_tables" ][idx : idx + 1 , :], - 1 )
501+ async_set_value (
502+ self . model_inputs [ " block_tables" ][ idx : idx + 1 , : encoder_block_num ], request . block_tables
498503 )
499- self .model_inputs ["stop_flags" ][idx : idx + 1 ] = False
500- self .model_inputs ["batch_drop" ][idx : idx + 1 ] = False
501504
502- self .model_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = length
505+ async_set_value (self .model_inputs ["stop_flags" ][idx : idx + 1 ], False )
506+ async_set_value (self .model_inputs ["batch_drop" ][idx : idx + 1 ], False )
507+
508+ async_set_value (self .model_inputs ["seq_lens_encoder" ][idx : idx + 1 ], length )
503509 self .exist_prefill_flag = True
504- self .model_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = prefill_start_index
505- self .model_inputs ["seq_lens_this_time_buffer" ][idx : idx + 1 ] = length
506- self .model_inputs ["step_idx" ][idx : idx + 1 ] = (
507- len (request .output_token_ids ) if prefill_end_index >= len (input_ids ) else 0
510+ async_set_value (self .model_inputs ["seq_lens_decoder" ][idx : idx + 1 ], prefill_start_index )
511+ async_set_value (self .model_inputs ["seq_lens_this_time_buffer" ][idx : idx + 1 ], length )
512+ async_set_value (
513+ self .model_inputs ["step_idx" ][idx : idx + 1 ],
514+ len (request .output_token_ids ) if prefill_end_index >= len (input_ids ) else 0 ,
508515 )
509516 if self .use_attn_mask_offset :
510517 inputs = request .multimodal_inputs
@@ -522,18 +529,19 @@ def insert_tasks_v1(
522529 if (
523530 self .fd_config .scheduler_config .splitwise_role == "decode"
524531 ): # In PD, we continue to decode after P generates first token
525- self .model_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = 0
532+ async_set_value ( self .model_inputs ["seq_lens_encoder" ][idx : idx + 1 ], 0 )
526533 self .exist_prefill_flag = False
527- self .model_inputs ["recompute_token_num" ][idx : idx + 1 ] = 0
528- self .model_inputs ["seq_lens_this_time_buffer" ][idx : idx + 1 ] = length + 1
534+ async_set_value (self .model_inputs ["seq_lens_this_time_buffer" ][idx : idx + 1 ], length + 1 )
529535 # NOTE(liuzichang):
530536 # extra 1 : P-D split need rollback one step
531- self .model_inputs ["mask_rollback" ][idx : idx + 1 ] = 1
537+
538+ async_set_value (self .model_inputs ["recompute_token_num" ][idx : idx + 1 ], 0 )
539+ async_set_value (self .model_inputs ["mask_rollback" ][idx : idx + 1 ], 1 )
532540 # has_prefill_task = True
533541 elif request .task_type .value == RequestType .DECODE .value : # decode task
534542 encoder_block_num = len (request .block_tables )
535- self .model_inputs ["encoder_block_lens" ][idx : idx + 1 ] = encoder_block_num
536- self .model_inputs ["block_tables" ][idx : idx + 1 , :] = - 1
543+ async_set_value ( self .model_inputs ["encoder_block_lens" ][idx : idx + 1 ], encoder_block_num )
544+ async_set_value ( self .model_inputs ["block_tables" ][idx : idx + 1 , :], - 1 )
537545 if current_platform .is_cuda ():
538546 async_set_value (
539547 self .model_inputs ["block_tables" ][idx : idx + 1 , :encoder_block_num ], request .block_tables
@@ -542,16 +550,13 @@ def insert_tasks_v1(
542550 self .model_inputs ["block_tables" ][idx : idx + 1 , :encoder_block_num ] = np .array (
543551 request .block_tables , dtype = "int32"
544552 )
545- # if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
546- # has_decode_task = True
547- # continue
548553 else :
549- self .model_inputs ["block_tables" ][idx : idx + 1 , :] = - 1
550- self .model_inputs ["stop_flags" ][idx : idx + 1 ] = True
551- self .model_inputs ["seq_lens_this_time_buffer" ][idx : idx + 1 ] = 0
552- self .model_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = 0
553- self .model_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = 0
554- self .model_inputs ["is_block_step" ][idx : idx + 1 ] = False
554+ async_set_value ( self .model_inputs ["block_tables" ][idx : idx + 1 , :], - 1 )
555+ async_set_value ( self .model_inputs ["stop_flags" ][idx : idx + 1 ], True )
556+ async_set_value ( self .model_inputs ["seq_lens_this_time_buffer" ][idx : idx + 1 ], 0 )
557+ async_set_value ( self .model_inputs ["seq_lens_decoder" ][idx : idx + 1 ], 0 )
558+ async_set_value ( self .model_inputs ["seq_lens_encoder" ][idx : idx + 1 ], 0 )
559+ async_set_value ( self .model_inputs ["is_block_step" ][idx : idx + 1 ], False )
555560 continue
556561
557562 # TODO(liuzichang): Solve splitewise-p bug to restore
@@ -1233,6 +1238,7 @@ def _update_status(self):
12331238 )
12341239
12351240 def _extend_draft_token_with_ngram_match (self ):
1241+ # TODO: replace with gpu tensor
12361242 hybrid_mtp_ngram (
12371243 self .model_inputs ["input_ids_cpu" ].cuda (),
12381244 self .model_inputs ["input_ids_len" ].cuda (),
0 commit comments