@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
2424 const int max_batch_size) {
2525 constexpr int VecSize = 4 ;
2626 const uint32_t tid = threadIdx .x , bid = blockIdx .x ;
27- int startend_row_vec[4 ];
27+ int startend_row_vec[2 ];
2828#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
2929 cudaGridDependencySynchronize ();
3030#endif
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
4949 const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
5050
5151 startend_row_vec[0 ] = this_batch_q_end;
52- startend_row_vec[1 ] = cu_seqlens_q[max_batch_size];
53- startend_row_vec[2 ] = 0 ;
54- startend_row_vec[3 ] = this_batch_q_end;
52+ // startend_row_vec[1] = cu_seqlens_q[max_batch_size];
53+ // startend_row_vec[2] = 0;
54+ startend_row_vec[1 ] = this_batch_q_end;
5555 for (int this_batch_q_idx = this_batch_q_start;
5656 this_batch_q_idx < this_batch_q_end;
5757 ++this_batch_q_idx) {
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
6262 : this_batch_q_idx - this_batch_q_start + kv_len -
6363 (this_batch_q_len);
6464 if (cache_k_idx <= append_mask_k_end) {
65- startend_row_vec[3 ] = min (startend_row_vec[3 ], this_batch_q_idx);
65+ startend_row_vec[1 ] = min (startend_row_vec[1 ], this_batch_q_idx);
6666 // 可提前跳出循环
6767 break ;
6868 }
6969 }
70- reinterpret_cast <int4 *>(startend_row_indices_ptr +
71- cu_seqlens_k_idx * 4 )[0 ] =
72- reinterpret_cast <int4 *>(startend_row_vec)[0 ];
70+ reinterpret_cast <int2 *>(startend_row_indices_ptr +
71+ cu_seqlens_k_idx * 2 )[0 ] =
72+ reinterpret_cast <int2 *>(startend_row_vec)[0 ];
7373 }
7474#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
7575 cudaTriggerProgrammaticLaunchCompletion ();
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
8282 const paddle::optional<paddle::Tensor>& attn_mask_kv,
8383 const int kv_token_num) {
8484 paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor (
85- {1 , 1 , kv_token_num, 4 }, paddle::DataType::INT32, cu_seqlens_k.place ());
85+ {1 , 1 , kv_token_num, 2 }, paddle::DataType::INT32, cu_seqlens_k.place ());
8686 const int max_batch_size = cu_seqlens_k.dims ()[0 ] - 1 ;
8787 constexpr int block_size = 512 ;
8888 int grid_size = div_up (kv_token_num, block_size);
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
123123 const std::vector<int64_t >& cu_seqlens_k_shape,
124124 const paddle::optional<std::vector<int64_t >>& attn_mask_kv_shape,
125125 const int kv_token_num) {
126- return {{1 , 1 , kv_token_num, 4 }};
126+ return {{1 , 1 , kv_token_num, 2 }};
127127}
128128
129129PD_BUILD_STATIC_OP (get_attn_mask_q)
0 commit comments