Conversation

@yuanlehome (Collaborator) commented Oct 11, 2024

PR types

New features

PR changes

Others

Description

Refactors the attention network for large-model inference. The new append_attn scheme delivers a 10% to 90% performance improvement over the old scheme.

Inference is currently supported for llama/qwen/qwen-moe/mixtral.

Usage: in the original inference script, simply replace the --block_attn option with --append_attn.
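For example (a hedged sketch; the script name and the elided flags are placeholders for whatever the existing inference command already uses):

python predictor.py ... --block_attn    # before
python predictor.py ... --append_attn   # after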

TODO:

  • fp8 inference support
  • Supplementary performance data, to follow in the llm docs

@paddle-bot bot commented Oct 11, 2024

Thanks for your contribution!

@codecov bot commented Oct 11, 2024

Codecov Report

Attention: Patch coverage is 0% with 60 lines in your changes missing coverage. Please review.

Project coverage is 52.74%. Comparing base (fe8b527) to head (84a6864).
Report is 264 commits behind head on develop.

Files with missing lines Patch % Lines
...erimental/transformers/fused_transformer_layers.py 0.00% 38 Missing ⚠️
...dlenlp/experimental/transformers/qwen2/modeling.py 0.00% 8 Missing ⚠️
...dlenlp/experimental/transformers/llama/modeling.py 0.00% 7 Missing ⚠️
...enlp/experimental/transformers/mixtral/modeling.py 0.00% 5 Missing ⚠️
...lp/experimental/transformers/qwen2_moe/modeling.py 0.00% 1 Missing ⚠️
paddlenlp/experimental/transformers/utils.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #9244   +/-   ##
========================================
  Coverage    52.73%   52.74%           
========================================
  Files          661      661           
  Lines       107422   107371   -51     
========================================
- Hits         56653    56630   -23     
+ Misses       50769    50741   -28     

☔ View full report in Codecov by Sentry.

static_cast<uint8_t>(quant_value2 + 128.0f);
}
// write k
// large block: lane_id / 4 / 2
Contributor

Please remove the Chinese comments.

Collaborator Author

There are too many of them; leaving them in does no real harm.
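For context, the quoted kernel line stores a quantized value in the unsigned 8-bit KV cache by shifting the signed int8 range up by 128. A minimal Python sketch of that mapping (the scale factor and the clamping step are assumptions, not the kernel's exact logic):

def quantize_to_uint8(value, scale):
    # Map a real value to the signed int8 range, then shift by 128 so it
    # fits in unsigned 8-bit cache storage, mirroring
    # static_cast<uint8_t>(quant_value2 + 128.0f) in the kernel above.
    q = max(-128.0, min(127.0, round(value / scale)))
    return int(q) + 128  # stored range: 0..255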

weight_scales_loader = EmptyWeightScale(
    weight_scale_map_dict,
    num_of_layers=self.config.num_hidden_layers,
    num_head=self.num_attention_heads,
Contributor

This has never taken effect; what problems does that cause?

Collaborator Author

This was changed while debugging; it is only a rename and has no effect.

for i_layer, weight_scale in enumerate(v):
    weight_scale = weight_scale.astype("float32")
    if self.config.append_attn:
        weight_scale = paddle.to_tensor(weight_scale).cast(paddle.get_default_dtype())
Contributor

Why can fp32 be skipped under append_attn?

Collaborator Author

Because the kernel implementation requires half precision; the memory-access pattern is different.
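In other words, when append_attn is on the scales are cast from float32 down to the default dtype before the kernel reads them. A minimal sketch of that cast (the tensor and the flag here are stand-ins for the real config and weights):

import paddle

append_attn = True  # stands in for self.config.append_attn
weight_scale = paddle.rand([8], dtype="float32")  # stand-in for a real scale tensor

if append_attn:
    # The append_attn kernels read the scales at half precision, so the
    # float32 values are cast to the default dtype (float16/bfloat16 in
    # typical inference setups) before being handed to the kernel.
    weight_scale = weight_scale.cast(paddle.get_default_dtype())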

print("***********Start Benchmark**********")

warmup_time = 10
test_time = 100
Contributor

What is this change for?

Collaborator Author

This change has no real effect; I didn't notice it had slipped into the commit.

Contributor

Please remember to revert it in the next PR.
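For reference, the warmup_time/test_time pair drives the usual warm-up-then-measure pattern; a generic sketch (run_once is a hypothetical stand-in for one inference step):

import time

def benchmark(run_once, warmup_time=10, test_time=100):
    for _ in range(warmup_time):      # warm up: fill caches, trigger lazy init
        run_once()
    start = time.perf_counter()
    for _ in range(test_time):        # measured iterations
        run_once()
    return (time.perf_counter() - start) / test_time  # average seconds per run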

def set_transformer_block(self, transformer_config):
    if self.use_weight_only:
        self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)
    elif "a8w8" in self.quant_type:
Contributor

Why was this deleted?

Collaborator Author

Because it was never actually supported; a colleague copied it over along with other reference code, so I removed it here while I was at it.
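Based on that reply, the dispatch keeps only the supported paths. A sketch of the resulting shape, inferred from the discussion rather than copied from the diff (the Fused* classes are the existing PaddleNLP ones; the real method may still have further branches):

def set_transformer_block(self, transformer_config):
    # Dispatch to a fused transformer implementation; the "a8w8" branch
    # is gone because that path was never actually supported.
    if self.use_weight_only:
        self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)
    else:
        self.transformer_block = FusedBlockMultiTransformer(transformer_config)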

@ZHUI merged commit 31c6b9a into PaddlePaddle:develop Oct 23, 2024
lvdongyi pushed a commit to lvdongyi/PaddleNLP that referenced this pull request Oct 23, 2024
* refine paddle::empty(), fix memory error, support multi_stream for attention

* fix and rename attention as append_attention

* rename file
---------

Co-authored-by: lizhenyun <[email protected]>
Co-authored-by: lizhenyun01 <[email protected]>