Skip to content

Commit 26c47c2

Browse files
authored
update attn_mask_q 2 (#7371)
1 parent 0ddb6e4 commit 26c47c2

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

custom_ops/gpu_ops/get_attn_mask_q.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
2424
const int max_batch_size) {
2525
constexpr int VecSize = 4;
2626
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
27-
int startend_row_vec[4];
27+
int startend_row_vec[2];
2828
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
2929
cudaGridDependencySynchronize();
3030
#endif
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
4949
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
5050

5151
startend_row_vec[0] = this_batch_q_end;
52-
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
53-
startend_row_vec[2] = 0;
54-
startend_row_vec[3] = this_batch_q_end;
52+
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
53+
// startend_row_vec[2] = 0;
54+
startend_row_vec[1] = this_batch_q_end;
5555
for (int this_batch_q_idx = this_batch_q_start;
5656
this_batch_q_idx < this_batch_q_end;
5757
++this_batch_q_idx) {
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
6262
: this_batch_q_idx - this_batch_q_start + kv_len -
6363
(this_batch_q_len);
6464
if (cache_k_idx <= append_mask_k_end) {
65-
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
65+
startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx);
6666
// can break out of the loop early here
6767
break;
6868
}
6969
}
70-
reinterpret_cast<int4*>(startend_row_indices_ptr +
71-
cu_seqlens_k_idx * 4)[0] =
72-
reinterpret_cast<int4*>(startend_row_vec)[0];
70+
reinterpret_cast<int2*>(startend_row_indices_ptr +
71+
cu_seqlens_k_idx * 2)[0] =
72+
reinterpret_cast<int2*>(startend_row_vec)[0];
7373
}
7474
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
7575
cudaTriggerProgrammaticLaunchCompletion();
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
8282
const paddle::optional<paddle::Tensor>& attn_mask_kv,
8383
const int kv_token_num) {
8484
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
85-
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
85+
{1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place());
8686
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
8787
constexpr int block_size = 512;
8888
int grid_size = div_up(kv_token_num, block_size);
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
123123
const std::vector<int64_t>& cu_seqlens_k_shape,
124124
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
125125
const int kv_token_num) {
126-
return {{1, 1, kv_token_num, 4}};
126+
return {{1, 1, kv_token_num, 2}};
127127
}
128128

129129
PD_BUILD_STATIC_OP(get_attn_mask_q)

0 commit comments

Comments
 (0)