@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
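Note on the new attribute: `softmax_scale` is threaded through every kernel dispatch below. In standard attention the QK^T logits are multiplied by a scale before the softmax, conventionally 1/sqrt(head_dim); exposing the scale as an op attribute lets the caller override that default (which is useful once the QK and V head dims are decoupled, as in the shape changes further down). A minimal sketch of where such a scale enters the computation, with hypothetical names only, not the kernel's actual code:

    // Illustrative sketch, not the kernel implementation: names are stand-ins.
    #include <cmath>

    float attention_logit(const float* q, const float* k, int head_dim,
                          float softmax_scale) {  // e.g. 1.0f / sqrtf(head_dim)
      float dot = 0.0f;
      for (int i = 0; i < head_dim; ++i) dot += q[i] * k[i];
      return softmax_scale * dot;  // scaled logit, fed to the softmax over keys
    }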
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::INT8,
           qkv.place());
     }
     else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
     } else {
       PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
     }
   } else {
     fmha_out = GetEmptyTensor(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
         D,
         qkv.place());
   }
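Sizing `fmha_out` with `head_dims_v` instead of `head_dims` matters once the QK head dim and the V head dim are allowed to differ: each attention output vector has the value head dim, so the flattened output width is `q_num_heads * head_dims_v`. A small example with made-up numbers (illustrative only, not taken from this PR):

    // Hypothetical config where the QK and V head dims differ.
    const int token_nums  = 8;
    const int q_num_heads = 16;
    const int head_dims   = 192;  // QK head dim, key_cache.dims()[3]
    const int head_dims_v = 128;  // V head dim, value_cache.dims()[3]
    // fmha_out shape: {token_nums, q_num_heads * head_dims_v} = {8, 2048};
    // the 192 QK head dim never appears in the output shape.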
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         encoder_block_shape_q,
         max_input_length,
         max_enc_len_this_time_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         encoder_block_shape_q,
         max_input_length,
         max_enc_len_this_time_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         encoder_block_shape_q,
         max_input_length,
         max_enc_len_this_time_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         decoder_block_shape_q,
         max_input_length,
         max_len_kv_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         decoder_block_shape_q,
         max_input_length,
         max_len_kv_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         decoder_block_shape_q,
         max_input_length,
         max_len_kv_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.token_nums = qkv_dims[0];
   meta_data.kv_num_heads = key_cache_dims[1];
   meta_data.head_dims = key_cache_dims[3];
-  const int total_num_head =
-      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
-  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+  meta_data.head_dims_v = value_cache.dims()[3];
+  const int q_hidden_size =
+      qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;

   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = key_cache.dims()[2];
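The head-count derivation changes accordingly. The packed QKV hidden size is q_num_heads * head_dims + kv_num_heads * (head_dims + head_dims_v); the old formula (divide by head_dims, then subtract 2 * kv_num_heads) only holds when head_dims == head_dims_v, so the K/V portion is now subtracted explicitly before dividing by the QK head dim. A worked example with the same made-up numbers as above (not values from this PR):

    // Illustrative only: recovering q_num_heads from the packed QKV width.
    const int kv_num_heads = 2;
    const int head_dims    = 192;                         // QK head dim
    const int head_dims_v  = 128;                         // V head dim
    const int qkv_hidden   = 16 * 192 + 2 * (192 + 128);  // 3712 for 16 query heads
    const int q_hidden     = qkv_hidden - kv_num_heads * (head_dims + head_dims_v);  // 3072
    const int q_num_heads  = q_hidden / head_dims;        // 3072 / 192 = 16
    // AppendAttentionInferShape below applies the same arithmetic to the shape hints.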
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
   const int token_num = qkv_shape[0];
   const int kv_num_heads = key_cache_shape[1];
-  const int head_dim = key_cache_shape[3];
-  const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
-  const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  const int head_dim_qk = key_cache_shape[3];
+  const int head_dim_v = value_cache_shape[3];
+  const int q_hidden_size =
+      qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+  const int num_heads = q_hidden_size / head_dim_qk;
+  return {{token_num, num_heads * head_dim_v}, qkv_shape};
 }

 std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -865,6 +880,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
            "cache_quant_type: std::string",
            "use_neox_rotary_style: bool",
            "max_input_length: int",
+           "softmax_scale: float",
            "quant_max_bound: float",
            "quant_min_bound: float",
            "out_linear_in_scale: float",