Skip to content

Commit de0c5e6

Browse files
authored
[XPU] Split the block_attn operator into smaller operators (#6798)
* split block_attn into smaller operators
* adapt to latest vLLM
* fix unit tests
* delete MTP + CUDA-graph 4-card test
* fix VL model
* fix MTP
* fix slot mapping
1 parent 6b891da commit de0c5e6

12 files changed

Lines changed: 2891 additions & 131 deletions

File tree

custom_ops/xpu_ops/src/ops/block_attn.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
159159
if (use_neox_rotary_style) {
160160
pos_emb_type = "NEOX";
161161
} else if (rope_head_dim == head_dim / 2) {
162+
// vl model use this
162163
pos_emb_type = "HALF_HEAD_DIM";
163164
} else {
164165
pos_emb_type = "NORMAL";
@@ -984,7 +985,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
984985
return {block_attn_out};
985986
}
986987

987-
std::vector<paddle::Tensor> BlockAttn(
988+
std::vector<paddle::Tensor> BlockAttnFused(
988989
const paddle::Tensor& qkv,
989990
const paddle::Tensor& key_cache,
990991
const paddle::Tensor& value_cache,
@@ -1008,6 +1009,8 @@ std::vector<paddle::Tensor> BlockAttn(
10081009
const paddle::Tensor& decoder_context_len_cache,
10091010
const paddle::Tensor& decoder_batch_map,
10101011
const paddle::Tensor& prefix_len,
1012+
const paddle::Tensor& slot_mapping_enc,
1013+
const paddle::Tensor& slot_mapping_dec,
10111014
const paddle::optional<paddle::Tensor>& k_scales,
10121015
const paddle::optional<paddle::Tensor>& v_scales,
10131016
const paddle::optional<paddle::Tensor>& k_scales_inv,
@@ -1067,7 +1070,7 @@ std::vector<paddle::Tensor> BlockAttn(
10671070
} else if (cache_dtype == paddle::DataType::INT8) {
10681071
APPLY_KERNEL(paddle::bfloat16, int8_t, paddle::bfloat16);
10691072
} else {
1070-
PD_THROW("block_attn not support cache_dtype==%d",
1073+
PD_THROW("block_attn_fused not support cache_dtype==%d",
10711074
static_cast<int>(cache_dtype));
10721075
return {};
10731076
}
@@ -1097,7 +1100,7 @@ std::vector<paddle::DataType> BlockAttnInferDtype(
10971100
return {qkv_dtype};
10981101
}
10991102

1100-
PD_BUILD_STATIC_OP(block_attn)
1103+
PD_BUILD_STATIC_OP(block_attn_fused)
11011104
.Inputs({"qkv",
11021105
"key_cache",
11031106
"value_cache",
@@ -1121,6 +1124,8 @@ PD_BUILD_STATIC_OP(block_attn)
11211124
"decoder_context_len_cache",
11221125
"decoder_batch_map",
11231126
"prefix_len",
1127+
"slot_mapping_enc",
1128+
"slot_mapping_dec",
11241129
paddle::Optional("k_scales"),
11251130
paddle::Optional("v_scales"),
11261131
paddle::Optional("k_scales_inv"),
@@ -1135,6 +1140,6 @@ PD_BUILD_STATIC_OP(block_attn)
11351140
paddle::Optional("cachekv_signal_thread_cpu")})
11361141
.Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
11371142
.Outputs({"block_attn_out"})
1138-
.SetKernelFn(PD_KERNEL(BlockAttn))
1143+
.SetKernelFn(PD_KERNEL(BlockAttnFused))
11391144
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
11401145
.SetInferDtypeFn(PD_INFER_DTYPE(BlockAttnInferDtype));

0 commit comments

Comments (0)