Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
b072465
append_attention 0914
yuanlehome Sep 14, 2024
b915f95
paddle::empty to phi::allocator
yuanlehome Sep 14, 2024
9b1e1d8
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Sep 19, 2024
140a509
append_attn 0919
yuanlehome Sep 20, 2024
5272b6f
0920 fix split_kv_block
yuanlehome Sep 20, 2024
a42157d
my change for merge 4 to 1
yuanlehome Sep 23, 2024
bec8eef
fix prev
yuanlehome Sep 23, 2024
8dab056
merge zhenyun 0923
yuanlehome Sep 23, 2024
d5047b5
fix prev
yuanlehome Sep 23, 2024
006a467
fix var name
yuanlehome Sep 23, 2024
73e2c06
update
yuanlehome Sep 23, 2024
a8acb2b
fix config
yuanlehome Sep 24, 2024
ec46a89
fix
yuanlehome Sep 24, 2024
cb02ee5
fix append_attn
lizhenyun01 Sep 27, 2024
83a19a6
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Sep 27, 2024
37fc7da
fix --use_fake_parameter
yuanlehome Sep 27, 2024
a3b265b
refine paddle::empty(), fix memory error, support multi_stream for at…
yuanlehome Sep 29, 2024
68a09b6
fix and rename attention as append_attention
yuanlehome Sep 29, 2024
2bcd939
rename file
yuanlehome Sep 29, 2024
74941a0
fix
yuanlehome Sep 29, 2024
19a0bdb
encoder GQANEOX rope support
lizhenyun01 Oct 8, 2024
a9078cb
decoder a8w8c8 GQANEOX rope support
lizhenyun01 Oct 8, 2024
f64f962
merge get_block_shape and split_kv_block
yuanlehome Oct 8, 2024
7ba73f8
bf16 neox rope support
lizhenyun01 Oct 9, 2024
6837c23
fix diff
lizhenyun01 Oct 9, 2024
0a5ae96
separate compilation
lizhenyun01 Oct 9, 2024
e9cfc55
manual destroy stream
lizhenyun01 Oct 9, 2024
478c517
fix multi stream
yuanlehome Oct 10, 2024
aa1e96a
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 10, 2024
e8ddfe8
qwen/llama support weightonly
yuanlehome Oct 10, 2024
8798938
fix multi stream
yuanlehome Oct 10, 2024
f6a64d0
qwen-moe and mixtral support append_attn
yuanlehome Oct 10, 2024
2292780
refine code
yuanlehome Oct 11, 2024
036fb73
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 11, 2024
b85782d
decoder neox_rope_c4 support
lizhenyun01 Oct 11, 2024
9814578
instantiation of append_attn with float16
lizhenyun01 Oct 11, 2024
7a1f591
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 12, 2024
5c126ad
optimize cpu performance
yuanlehome Oct 12, 2024
2ef7c11
format code
yuanlehome Oct 12, 2024
4a4a4b4
c16/c8/c4 分离编译 加快编译速度
yuanlehome Oct 15, 2024
0e35a1e
fix bug
yuanlehome Oct 15, 2024
c5b4633
gqa_group_size -> kv_num_heads
yuanlehome Oct 15, 2024
ea8c07e
support speculate_attn
lizhenyun01 Oct 15, 2024
3789175
adjust network
yuanlehome Oct 16, 2024
6eacbca
cache_int4 -> cache_int4_zp
yuanlehome Oct 16, 2024
358115d
fix use_fake_parameter multi cards
yuanlehome Oct 17, 2024
30ac44c
fix speculate_decoder
lizhenyun01 Oct 17, 2024
4011d89
delete comment
lizhenyun01 Oct 17, 2024
7efff99
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 21, 2024
c30c112
Merge branch 'append_attn' of https://github.com/yuanlehome/PaddleNLP…
yuanlehome Oct 21, 2024
84a6864
fix ci
yuanlehome Oct 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
separate compilation
  • Loading branch information
lizhenyun01 committed Oct 9, 2024
commit 0a5ae9683cd3b65a6d704857a35fd9c54bbb8221
51 changes: 51 additions & 0 deletions csrc/gpu/append_attn/append_attention_bfloat16_bfloat16_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "append_attention_kernel.h"

// Explicit instantiation of CascadeAppendAttentionKernel for
// T = paddle::bfloat16 and a paddle::bfloat16 second template argument
// (presumably the KV-cache element type, given the file name
// append_attention_bfloat16_bfloat16_kernel.cu — confirm against the header).
// Each <T, CacheT> combination lives in its own .cu file so the combinations
// can be compiled as separate translation units. The primary template is
// presumably declared in append_attention_kernel.h (included above); an
// explicit instantiation must match that declaration's parameter list exactly.
template void CascadeAppendAttentionKernel<paddle::bfloat16, paddle::bfloat16>(
// NOTE(review): the tensor is named qkv, so it presumably packs Q, K and V;
// the shape comment below mentions only num_heads — confirm the real layout.
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
// NOTE(review): the scale/zp tensors below are per kv head, so the head dim of
// the caches is likely kv_num_heads rather than num_heads — confirm.
const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const std::string& cache_quant_type_str,
const int num_blocks,
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
const int num_heads,
const int kv_num_heads,
const int head_dim,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
const bool enable_prefill,
// The stream is taken by non-const reference, exactly as in the primary template.
cudaStream_t& stream,
// presumably the output tensor the kernel writes into — confirm in the header.
paddle::Tensor* out);
51 changes: 51 additions & 0 deletions csrc/gpu/append_attn/append_attention_bfloat16_int8_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "append_attention_kernel.h"

// Explicit instantiation of CascadeAppendAttentionKernel for
// T = paddle::bfloat16 and an int8_t second template argument (presumably the
// KV-cache element type for an int8-quantized cache, given the file name
// append_attention_bfloat16_int8_kernel.cu — confirm against the header).
// Each <T, CacheT> combination lives in its own .cu file so the combinations
// can be compiled as separate translation units. The primary template is
// presumably declared in append_attention_kernel.h (included above); an
// explicit instantiation must match that declaration's parameter list exactly.
template void CascadeAppendAttentionKernel<paddle::bfloat16, int8_t>(
// NOTE(review): the tensor is named qkv, so it presumably packs Q, K and V;
// the shape comment below mentions only num_heads — confirm the real layout.
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
// NOTE(review): the scale/zp tensors below are per kv head, so the head dim of
// the caches is likely kv_num_heads rather than num_heads — confirm.
const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim]
const paddle::Tensor& cache_v, // [max_block_num, num_heads, head_dim, block_size]
const paddle::optional<paddle::Tensor>& attn_mask,
// presumably used to dequantize the int8 cache; optional per the signature.
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const std::string& cache_quant_type_str,
const int num_blocks,
const int block_shape_q,
const int max_seq_len,
const int max_dec_len,
const int num_heads,
const int kv_num_heads,
const int head_dim,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
const bool enable_prefill,
// The stream is taken by non-const reference, exactly as in the primary template.
cudaStream_t& stream,
// presumably the output tensor the kernel writes into — confirm in the header.
paddle::Tensor* out);
74 changes: 17 additions & 57 deletions csrc/gpu/append_attn/append_attention_func.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "helper.h"
Expand Down Expand Up @@ -2030,39 +2043,17 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
__syncthreads();
#endif
OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf("o_ptr end.\n");
}
__syncthreads();
#endif

uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) { // num_frags_y * 16 / (8[tid] *
// num_elems_per_128b<T>()[vec_per_thread])
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf("n_offset: %d, qo_upper_bound:%d.\n", n_offset, qo_upper_bound);
}
__syncthreads();
#endif

if (n_offset < qo_upper_bound) {
if constexpr (!partition_kv) {
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf(
"write_o_reg_gmem_multi_warps_shift_smooth_quant load start");
}
__syncthreads();
#endif

if (in_scale > 0.0) {
if (shift_bias) {
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
Expand All @@ -2074,14 +2065,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
Load<T, VEC_SIZE>(
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf("write_o_reg_gmem_multi_warps_shift_smooth_quant load end");
}
__syncthreads();
#endif

#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
Expand Down Expand Up @@ -2112,23 +2096,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
__syncthreads();
#endif
}
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf("Store start");
}
__syncthreads();
#endif
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf("Store end");
}
__syncthreads();
#endif
} else {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
Expand All @@ -2143,14 +2111,6 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
2 * num_frags_y;
// }
}
#ifdef DEBUG_ATTN
__syncthreads();
if (threadIdx.x == PRINT_TID && threadIdx.y == 0 && blockIdx.z == 0 &&
blockIdx.x == gridDim.x - 1) {
printf("kernel end");
}
__syncthreads();
#endif
}

template <uint32_t group_size,
Expand Down
Loading