Merged
Changes from 1 commit
Commits
51 commits
b072465
append_attention 0914
yuanlehome Sep 14, 2024
b915f95
paddle::empty to phi::allocator
yuanlehome Sep 14, 2024
9b1e1d8
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Sep 19, 2024
140a509
append_attn 0919
yuanlehome Sep 20, 2024
5272b6f
0920 fix split_kv_block
yuanlehome Sep 20, 2024
a42157d
my change for merge 4 to 1
yuanlehome Sep 23, 2024
bec8eef
fix prev
yuanlehome Sep 23, 2024
8dab056
merge zhenyun 0923
yuanlehome Sep 23, 2024
d5047b5
fix prev
yuanlehome Sep 23, 2024
006a467
fix var name
yuanlehome Sep 23, 2024
73e2c06
update
yuanlehome Sep 23, 2024
a8acb2b
fix config
yuanlehome Sep 24, 2024
ec46a89
fix
yuanlehome Sep 24, 2024
cb02ee5
fix append_attn
lizhenyun01 Sep 27, 2024
83a19a6
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Sep 27, 2024
37fc7da
fix --use_fake_parameter
yuanlehome Sep 27, 2024
a3b265b
refine paddle::empty(), fix memory error, support multi_stream for at…
yuanlehome Sep 29, 2024
68a09b6
fix and rename attention as append_attention
yuanlehome Sep 29, 2024
2bcd939
rename file
yuanlehome Sep 29, 2024
74941a0
fix
yuanlehome Sep 29, 2024
19a0bdb
encoder GQANEOX rope support
lizhenyun01 Oct 8, 2024
a9078cb
decoder a8w8c8 GQANEOX rope support
lizhenyun01 Oct 8, 2024
f64f962
merge get_block_shape and split_kv_block
yuanlehome Oct 8, 2024
7ba73f8
bf16 neox rope support
lizhenyun01 Oct 9, 2024
6837c23
fix diff
lizhenyun01 Oct 9, 2024
0a5ae96
separate compilation
lizhenyun01 Oct 9, 2024
e9cfc55
manual destroy stream
lizhenyun01 Oct 9, 2024
478c517
fix multi stream
yuanlehome Oct 10, 2024
aa1e96a
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 10, 2024
e8ddfe8
qwen/llama support weightonly
yuanlehome Oct 10, 2024
8798938
fix multi stream
yuanlehome Oct 10, 2024
f6a64d0
qwen-moe and mixtral support append_attn
yuanlehome Oct 10, 2024
2292780
refine code
yuanlehome Oct 11, 2024
036fb73
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 11, 2024
b85782d
decoder neox_rope_c4 support
lizhenyun01 Oct 11, 2024
9814578
instantiation of append_attn with float16
lizhenyun01 Oct 11, 2024
7a1f591
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 12, 2024
5c126ad
optimize cpu performance
yuanlehome Oct 12, 2024
2ef7c11
format code
yuanlehome Oct 12, 2024
4a4a4b4
c16/c8/c4 separate compilation to speed up builds
yuanlehome Oct 15, 2024
0e35a1e
fix bug
yuanlehome Oct 15, 2024
c5b4633
gqa_group_size -> kv_num_heads
yuanlehome Oct 15, 2024
ea8c07e
support speculate_attn
lizhenyun01 Oct 15, 2024
3789175
adjust network
yuanlehome Oct 16, 2024
6eacbca
cache_int4 -> cache_int4_zp
yuanlehome Oct 16, 2024
358115d
fix use_fake_parameter multi cards
yuanlehome Oct 17, 2024
30ac44c
fix speculate_decoder
lizhenyun01 Oct 17, 2024
4011d89
delete comment
lizhenyun01 Oct 17, 2024
7efff99
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
yuanlehome Oct 21, 2024
c30c112
Merge branch 'append_attn' of https://github.com/yuanlehome/PaddleNLP…
yuanlehome Oct 21, 2024
84a6864
fix ci
yuanlehome Oct 21, 2024
my change for merge 4 to 1
yuanlehome committed Sep 23, 2024
commit a42157d67b2fbf18d499abb301102259662889af
713 changes: 486 additions & 227 deletions csrc/gpu/append_attention.cu

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions csrc/gpu/append_attn/append_attention_func.cuh
@@ -658,7 +658,7 @@ __device__ __forceinline__ void block_produce_kv(
smem_t smem,
uint32_t* smem_offset,
T* gptr_base, // [max_block_num, num_heads, block_size, head_dim]
const int* block_table,
const int* block_tables,
const uint32_t kv_head_idx,
const uint32_t kv_n_stride,
const uint32_t kv_h_stride,
@@ -676,7 +676,7 @@
kv_idx_base + (i * 4 * num_warps + ty * 4 + tx / 8);
const uint32_t kv_n_idx = row_now / block_size;
const uint32_t kv_bid = row_now % block_size;
T* gptr = gptr_base + __ldg(&block_table[kv_n_idx]) * kv_n_stride +
T* gptr = gptr_base + __ldg(&block_tables[kv_n_idx]) * kv_n_stride +
kv_head_idx * kv_h_stride + kv_bid * kv_b_stride +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
@@ -703,7 +703,7 @@
const uint32_t row_now = kv_idx_base + (i * 16 + j * 4 + row_id_per_tx);
const uint32_t kv_n_idx = row_now / block_size;
const uint32_t kv_bid = row_now % block_size;
T* gptr = gptr_base + __ldg(&block_table[kv_n_idx]) * kv_n_stride +
T* gptr = gptr_base + __ldg(&block_tables[kv_n_idx]) * kv_n_stride +
kv_head_idx * kv_h_stride + kv_bid * kv_b_stride +
col_id_per_tx * num_elems_per_128b<T>();
#pragma unroll
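The hunks above change block_produce_kv to address the paged KV cache through block_tables: a logical token row is split into a block index (row_now / block_size) and an in-block offset (row_now % block_size), and the block index is translated to a physical block id via __ldg(&block_tables[kv_n_idx]). The sketch below redoes that address arithmetic on the host; block_size, the head counts, and the concrete table values are made-up assumptions, only the stride layout follows the [max_block_num, num_heads, block_size, head_dim] comment in the diff.

```cpp
// Host-side sketch (not the kernel itself) of how block_produce_kv resolves a
// logical KV row to an offset in the paged cache.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const uint32_t block_size = 64;        // tokens per cache block (assumption)
  const uint32_t head_dim = 128;         // assumption
  const uint32_t kv_num_heads = 8;       // assumption

  // Strides into a [max_block_num, kv_num_heads, block_size, head_dim] cache.
  const uint64_t kv_b_stride = head_dim;                    // next token within a block
  const uint64_t kv_h_stride = block_size * kv_b_stride;    // next head
  const uint64_t kv_n_stride = kv_num_heads * kv_h_stride;  // next physical block

  // One row of block_tables: logical block index -> physical block id (made up).
  const std::vector<int> block_table_now = {7, 2, 11, 5};

  const uint32_t row_now = 150;      // logical token index within the sequence
  const uint32_t kv_head_idx = 3;

  const uint32_t kv_n_idx = row_now / block_size;  // which logical block
  const uint32_t kv_bid = row_now % block_size;    // offset inside that block

  // Element offset of the row start; the kernel then adds a per-thread column
  // offset (tx % 8 * num_elems_per_128b<T>()) on top of this.
  const uint64_t offset =
      static_cast<uint64_t>(block_table_now[kv_n_idx]) * kv_n_stride +
      kv_head_idx * kv_h_stride + kv_bid * kv_b_stride;

  std::printf("logical row %u -> physical block %d, element offset %llu\n",
              row_now, block_table_now[kv_n_idx],
              static_cast<unsigned long long>(offset));
  return 0;
}
```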
24 changes: 12 additions & 12 deletions csrc/gpu/append_attn/append_attention_impl.cuh
@@ -26,7 +26,7 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ block_tables, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -52,7 +52,7 @@
const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16;
const int *block_table_now = nullptr;

block_table_now = block_table + batch_id * max_block_num_per_seq;
block_table_now = block_tables + batch_id * max_block_num_per_seq;

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
@@ -491,7 +491,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ block_tables, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -520,7 +520,7 @@
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids_per_batch[btid];
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
const int *block_table_now = block_tables + batch_id * max_block_num_per_seq;

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
@@ -1113,7 +1113,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ block_tables, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -1166,7 +1166,7 @@
const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16;
const int *block_table_now = nullptr;

block_table_now = block_table + batch_id * max_block_num_per_seq;
block_table_now = block_tables + batch_id * max_block_num_per_seq;

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
@@ -1731,7 +1731,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ block_tables, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -1782,7 +1782,7 @@
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids_per_batch[btid];
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
const int *block_table_now = block_tables + batch_id * max_block_num_per_seq;

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
@@ -2457,7 +2457,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ block_tables, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -2503,7 +2503,7 @@
const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16;
const int *block_table_now = nullptr;

block_table_now = block_table + batch_id * max_block_num_per_seq;
block_table_now = block_tables + batch_id * max_block_num_per_seq;

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
@@ -3129,7 +3129,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ block_tables, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -3177,7 +3177,7 @@
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids_per_batch[btid];
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
const int *block_table_now = block_tables + batch_id * max_block_num_per_seq;

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
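All six kernel variants above consume block_tables the same way: each thread block resolves its batch's row of the [bsz, block_num_per_seq] table at batch_id * max_block_num_per_seq and skips batches whose q_len is zero. A minimal host-side sketch of that lookup, with made-up shapes and values:

```cpp
// Per-batch row lookup and early-out, mirroring
// block_table_now = block_tables + batch_id * max_block_num_per_seq.
#include <cstdio>
#include <vector>

int main() {
  const int bsz = 3;
  const int max_block_num_per_seq = 4;
  // Flattened [bsz, block_num_per_seq] table; -1 marks unused slots (assumption).
  const std::vector<int> block_tables = {
      0, 1, 2, -1,   // batch 0
      3, 4, -1, -1,  // batch 1
      5, 6, 7, 8,    // batch 2
  };
  const std::vector<int> seq_lens = {70, 0, 200};  // q_len per batch (made up)

  for (int batch_id = 0; batch_id < bsz; ++batch_id) {
    const int q_len = seq_lens[batch_id];
    if (q_len <= 0) {
      // Same early-out as the kernels: nothing to compute for this batch.
      continue;
    }
    const int* block_table_now =
        block_tables.data() + batch_id * max_block_num_per_seq;
    std::printf("batch %d: q_len=%d, first physical block=%d\n",
                batch_id, q_len, block_table_now[0]);
  }
  return 0;
}
```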
48 changes: 24 additions & 24 deletions csrc/gpu/append_attn/append_attention_kernel.cu
@@ -38,7 +38,7 @@ void MultiQueryAppendAttention(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const int num_blocks_x_cpu,
@@ -63,7 +63,7 @@
const auto& cum_offsets_dims = cum_offsets.dims();
const uint32_t token_num = q_dims[0];
const uint32_t bsz = cum_offsets_dims[0];
const uint32_t max_block_num_per_seq = block_table.dims()[1];
const uint32_t max_block_num_per_seq = block_tables.dims()[1];

constexpr uint32_t num_warps = 4;
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
@@ -158,7 +158,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -212,7 +212,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -380,7 +380,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -448,7 +448,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -560,7 +560,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const int num_blocks_x_cpu,
@@ -589,7 +589,7 @@
const auto& cum_offsets_dims = cum_offsets.dims();
const uint32_t token_num = q_dims[0];
const uint32_t bsz = cum_offsets_dims[0];
const uint32_t max_block_num_per_seq = block_table.dims()[1];
const uint32_t max_block_num_per_seq = block_tables.dims()[1];

constexpr uint32_t num_warps = 4;
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
@@ -705,7 +705,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -770,7 +770,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -961,7 +961,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1034,7 +1034,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1144,7 +1144,7 @@ void MultiQueryAppendC4Attention(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const int num_blocks_x_cpu,
@@ -1169,7 +1169,7 @@
const auto& cum_offsets_dims = cum_offsets.dims();
const uint32_t token_num = q_dims[0];
const uint32_t bsz = cum_offsets_dims[0];
const uint32_t max_block_num_per_seq = block_table.dims()[1];
const uint32_t max_block_num_per_seq = block_tables.dims()[1];

constexpr uint32_t num_warps = 4;
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
@@ -1275,7 +1275,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1336,7 +1336,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1513,7 +1513,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1587,7 +1587,7 @@
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
block_tables.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1698,7 +1698,7 @@ void CascadeAppendAttentionKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const std::string& cache_quant_type_str,
@@ -1760,7 +1760,7 @@
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_table,
block_tables,
batch_ids,
tile_ids_per_batch,
num_blocks,
@@ -1802,7 +1802,7 @@
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_table,
block_tables,
batch_ids,
tile_ids_per_batch,
num_blocks,
@@ -1841,7 +1841,7 @@
seq_lens_encoder,
padding_offsets,
cum_offsets,
block_table,
block_tables,
batch_ids,
tile_ids_per_batch,
num_blocks,
@@ -1887,7 +1887,7 @@ template void CascadeAppendAttentionKernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const std::string& cache_quant_type_str,
@@ -1932,7 +1932,7 @@ template void CascadeAppendAttentionKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const std::string& cache_quant_type_str,
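This file forwards the same block_tables tensor to MultiQueryAppendAttention, MultiQueryAppendC8Attention, and MultiQueryAppendC4Attention; together with the cache_quant_type_str parameter and the "c16/c8/c4 separate compilation" and "cache_int4 -> cache_int4_zp" commits, that suggests CascadeAppendAttentionKernel selects a kernel family by cache quantization type. The sketch below is a guess at that selection logic; the string values and the simplified shape of the dispatcher are assumptions, not the repository's exact API.

```cpp
// Hypothetical dispatch on the KV-cache quantization type.
#include <cstdio>
#include <stdexcept>
#include <string>

enum class CacheKernel { kC16, kC8, kC4 };

CacheKernel SelectCacheKernel(const std::string& cache_quant_type_str) {
  if (cache_quant_type_str == "none") return CacheKernel::kC16;          // fp16/bf16 KV cache
  if (cache_quant_type_str == "cache_int8") return CacheKernel::kC8;     // int8-quantized cache
  if (cache_quant_type_str == "cache_int4_zp") return CacheKernel::kC4;  // int4 + zero point
  throw std::invalid_argument("unsupported cache_quant_type: " + cache_quant_type_str);
}

int main() {
  for (const char* type : {"none", "cache_int8", "cache_int4_zp"}) {
    std::printf("%s -> kernel family %d\n", type,
                static_cast<int>(SelectCacheKernel(type)));
  }
  return 0;
}
```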
2 changes: 1 addition & 1 deletion csrc/gpu/append_attn/append_attention_kernel.h
@@ -41,7 +41,7 @@ void CascadeAppendAttentionKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const std::string& cache_quant_type_str,
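The declaration above also carries batch_ids and tile_ids_per_batch, which the kernels read as batch_id = batch_ids[btid] and tile_id = tile_ids_per_batch[btid]. One plausible way such a flattened (batch, tile) schedule could be built on the host is sketched below; rows_per_tile and the sequence lengths are made-up values, not what get_block_shape actually computes.

```cpp
// Hypothetical construction of a flattened per-tile schedule:
// every batch contributes ceil(q_len / rows_per_tile) tiles.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> seq_lens = {70, 0, 200};  // per-batch q_len (made up)
  const int rows_per_tile = 64;                    // e.g. NUM_WARPS * num_frags_x * 16

  std::vector<int> batch_ids;
  std::vector<int> tile_ids_per_batch;
  for (int b = 0; b < static_cast<int>(seq_lens.size()); ++b) {
    const int num_tiles = (seq_lens[b] + rows_per_tile - 1) / rows_per_tile;
    for (int t = 0; t < num_tiles; ++t) {
      batch_ids.push_back(b);           // which batch this thread block works on
      tile_ids_per_batch.push_back(t);  // which row tile within that batch
    }
  }

  for (size_t btid = 0; btid < batch_ids.size(); ++btid) {
    std::printf("btid=%zu -> batch %d, tile %d\n",
                btid, batch_ids[btid], tile_ids_per_batch[btid]);
  }
  return 0;
}
```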