@@ -156,7 +156,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
     rope_head_dim = rotary_embs.dims()[4];
   }
   std::string pos_emb_type;
-  if (use_neox_rotary_style == true) {
+  if (use_neox_rotary_style) {
     pos_emb_type = "NEOX";
   } else if (rope_head_dim == head_dim / 2) {
     pos_emb_type = "HALF_HEAD_DIM";
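For readers skimming the patch: this hunk only drops the redundant `== true`, but the branch it sits in is what the later hunks build on — `pos_emb_type` tags which rotary layout the cache-KV kernels should use. A minimal sketch of that selection logic; the fallback label is our assumption, since the real `else` branch is not shown in the hunk:

```cpp
#include <string>

// Sketch of the selection this hunk tidies up: "NEOX" rotates interleaved
// halves; "HALF_HEAD_DIM" means the rotary table spans only head_dim / 2
// channels.
std::string SelectPosEmbType(bool use_neox_rotary_style,
                             int rope_head_dim,
                             int head_dim) {
  if (use_neox_rotary_style) {  // boolean used directly, as in the fix
    return "NEOX";
  }
  if (rope_head_dim == head_dim / 2) {
    return "HALF_HEAD_DIM";
  }
  return "FULL";  // hypothetical default for the unseen else branch
}
```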
@@ -342,12 +342,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
             value_cache.data<cdata_t>())),
         vsl.usual_lod_vp,     // seq_lod
         vsl.slot_mapping_vp,  // real_batch
+        prefix_lens_vp,       // start_tokens
         param.batch_size,     // batch_size
         1,                    // emb_batch_size
         rope_max_seqlen,      // max_seqlen
         param.head_num,
         param.kv_head_num,
         param.head_dim,
+        rope_head_dim,
         param.max_batch_size,
         block_size,
         max_block_per_seq,
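The substantive change in this hunk is the new `start_tokens` argument: with a prefix cache, the first uncached token of sequence `b` sits at absolute position `prefix_lens[b]`, not 0, so the rotary table must be indexed with that offset (and the new `rope_head_dim` bounds how many channels get rotated). A host-side sketch of that indexing, under our own layout assumptions — this is not the XPU kernel:

```cpp
// Illustrative only: apply rotary embedding to one head of packed queries,
// reading cos/sin rows at the absolute position start_token + i and rotating
// just the first rope_head_dim channels (the HALF_HEAD_DIM case leaves the
// rest untouched). A [max_seqlen, rope_head_dim] table layout is assumed.
void ApplyRopeSketch(float* q,                // [seq_len, head_dim]
                     const float* cos_table,  // [max_seqlen, rope_head_dim]
                     const float* sin_table,  // [max_seqlen, rope_head_dim]
                     int seq_len, int head_dim,
                     int rope_head_dim, int start_token) {
  for (int i = 0; i < seq_len; ++i) {
    const float* c = cos_table + (start_token + i) * rope_head_dim;
    const float* s = sin_table + (start_token + i) * rope_head_dim;
    float* row = q + i * head_dim;
    for (int d = 0; d + 1 < rope_head_dim; d += 2) {
      float x = row[d], y = row[d + 1];
      row[d]     = x * c[d] - y * s[d];
      row[d + 1] = x * s[d] + y * c[d];
    }
  }
}
```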
@@ -586,7 +588,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
     ret = infer_ops::
         split_neox_cache_kv_encoder<XPU_XType, float, XPU_CType, int>(
             xpu_ctx->x_context(),
-            reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()),  // qkv
+            reinterpret_cast<const XPU_XType*>(qkv.data<data_t>()) +
+                total_enc_len * qkv_shape[qkv_shape.size() - 1],  // qkv
             reinterpret_cast<const float*>(
                 rotary_embs.data<float>()),  // rotary_pos_emb
             reinterpret_cast<const int*>(
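This hunk advances the `qkv` base pointer so the call consumes rows starting `total_enc_len` tokens into the packed buffer, where `qkv_shape[qkv_shape.size() - 1]` is the per-token row width. The arithmetic, isolated as a sketch (`SkipRows` is our name, and the offset is in elements, not bytes, because the pointer is typed):

```cpp
#include <cstdint>

// qkv is packed as [total_tokens, qkv_width]; skipping the first
// total_enc_len rows means advancing the typed pointer by
// total_enc_len * qkv_width elements.
template <typename T>
const T* SkipRows(const T* qkv_base, int64_t total_enc_len,
                  int64_t qkv_width) {
  return qkv_base + total_enc_len * qkv_width;
}
```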
@@ -598,14 +601,16 @@ std::vector<paddle::Tensor> BlockAttnKernel(
             key_cache.data<cdata_t>())),
         const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
             value_cache.data<cdata_t>())),
-        decoder_seq_lod_vp,    // seq_lod
-        decoder_batch_map_vp,  // real_batch
-        param.batch_size,      // batch_size
-        1,                     // emb_batch_size
-        rope_max_seqlen,       // max_seqlen
+        decoder_seq_lod_vp,            // seq_lod
+        decoder_batch_map_vp,          // real_batch
+        decoder_context_len_cache_vp,  // start_tokens
+        param.batch_size,              // batch_size
+        1,                             // emb_batch_size
+        rope_max_seqlen,               // max_seqlen
         param.head_num,
         param.kv_head_num,
         param.head_dim,
+        rope_head_dim,
         param.max_batch_size,
         block_size,
         max_block_per_seq,
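On the decoder side the same `start_tokens` slot is fed from `decoder_context_len_cache_vp`: at each decode step a sequence emits one token, whose absolute rotary position is exactly the number of tokens already in its context cache. A trivial sketch of that relationship (names here are illustrative, not the kernel's):

```cpp
#include <vector>

// Next-token rotary positions for one decode step: one token per sequence,
// positioned right after its cached context.
std::vector<int> DecodeStepPositions(const std::vector<int>& context_lens) {
  std::vector<int> pos(context_lens.size());
  for (size_t b = 0; b < pos.size(); ++b) {
    pos[b] = context_lens[b];
  }
  return pos;
}
```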
@@ -806,6 +811,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         param.head_num,
         param.kv_head_num,
         param.head_dim,
+        rope_head_dim,
         param.max_batch_size,
         block_size,
         max_block_per_seq,