Merged
[DCU] Llama a8w8 inference performance optimization
Deleter-D committed Jul 24, 2024
commit 7ae994dc51572567117ad24ee8b54ec956d72380
14 changes: 8 additions & 6 deletions llm/predict/predictor.py
@@ -331,9 +331,9 @@ def stream_predict(self, inputs: dict[str, paddle.Tensor]):
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=get_eos_token_id(self.tokenizer, self.generation_config),
pad_token_id=self.tokenizer.pad_token_id,
decode_strategy="greedy_search"
if self.config.top_k == 1 and self.config.top_p == 1.0
else self.config.decode_strategy,
decode_strategy=(
"greedy_search" if self.config.top_k == 1 and self.config.top_p == 1.0 else self.config.decode_strategy
),
temperature=self.config.temperature,
top_k=self.config.top_k,
top_p=self.config.top_p,
@@ -1238,7 +1238,9 @@ def predict(self, input_texts: str | list[str], return_tokens=False):
def _preprocess(self, source):
BlockInferencePredictorMixin._preprocess(self, source)
for i, text in enumerate(source):
tokens = self.tokenizer(text, return_tensors="np", padding=False, truncation=True, max_length=(self.config.src_length))
tokens = self.tokenizer(
text, return_tensors="np", padding=False, truncation=True, max_length=(self.config.src_length)
)
input_ids = tokens["input_ids"][0]
length = len(input_ids)
need_block_nums = (
@@ -1650,8 +1652,8 @@ def predict():
target_texts.append("")

else:
source_texts = ["解释一下“温故而知新”", "你好,请问你是谁?"]
target_texts = ["", ""]
source_texts = ["你好,请问你是谁?"] * predictor_args.batch_size
target_texts = [""] * predictor_args.batch_size

batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size)
batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size)
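For context on the benchmark-input change above: the same prompt is now replicated batch_size times so every batch is full. The batchfy_text re-implementation below is an assumption for illustration only; the real helper imported by predictor.py may differ in details.

# Assumed re-implementation for illustration only; the real batchfy_text
# used by predictor.py may differ in details.
def batchfy_text(texts, batch_size):
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

batch_size = 4
source_texts = ["你好,请问你是谁?"] * batch_size  # one prompt repeated to fill the batch
target_texts = [""] * batch_size

batches = batchfy_text(source_texts, batch_size)
assert len(batches) == 1 and len(batches[0]) == batch_size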
41 changes: 29 additions & 12 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -550,7 +550,11 @@ def init_weight_shape(self, config):
if config.trans_qkvw
else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
)
self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
self.linear_weight_shape = (
[self.num_heads * self.head_dim, self.embed_dim]
if config.trans_qkvw
else [self.embed_dim, self.num_heads * self.head_dim]
)
self.ffn1_weight_shape = (
[self.embed_dim, self.dim_feedforward * 2]
if self.activation.endswith("glu")
@@ -1264,13 +1268,14 @@ def get_weight_create_dype(self):
def init_weight_shape(self, config):
super().init_weight_shape(config)

self.linear_weight_shape = [self.embed_dim, self.num_heads * self.head_dim]
self.ffn1_weight_shape = (
[self.dim_feedforward * 2, self.embed_dim]
if self.activation.endswith("glu")
else [self.dim_feedforward, self.embed_dim]
)
self.ffn2_weight_shape = [self.embed_dim, self.dim_feedforward]
if not paddle.is_compiled_with_rocm():
self.linear_weight_shape = [self.embed_dim, self.num_heads * self.head_dim]
self.ffn1_weight_shape = (
[self.dim_feedforward * 2, self.embed_dim]
if self.activation.endswith("glu")
else [self.dim_feedforward, self.embed_dim]
)
self.ffn2_weight_shape = [self.embed_dim, self.dim_feedforward]

def compute_layernorm_before_qkv(self, src, i):
if i == 0:
@@ -1291,7 +1296,10 @@ def compute_layernorm_before_qkv(self, src, i):
return ln_out

def compute_qkv_linear(self, ln_out, i):
qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
if paddle.is_compiled_with_rocm():
qkv_out = paddle.matmul(ln_out, self.qkv_weights[i])
else:
qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
return qkv_out

def compute_fmha(
@@ -1384,7 +1392,10 @@ def compute_mmha(self, qkv_out, caches, attn_mask, seq_lens, rotary_embs, rotary
)[0]

def compute_out_linear(self, fmha_out, i):
out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
if paddle.is_compiled_with_rocm():
[Review — Contributor]: Please explain in the PR description why the ROCm path needs the weights left untransposed.
[Review — Contributor Author]: Explanation has been added.
out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i])
else:
out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
return dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)

def compute_ffn_layernorm(self, out_linear_out, residual_input, i):
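A note on the review thread above: the ROCm branches drop the transpose_y=True argument because, with this PR, the weights are stored already transposed at load time (see the set_state_dict changes in llama/modeling.py below). Below is a minimal sketch of the equivalence, with illustrative shapes only rather than the model's real dimensions.

import numpy as np
import paddle

# Illustrative sizes only; the real shapes come from the model config.
batch, in_features, out_features = 4, 8, 16
x = paddle.to_tensor(np.random.rand(batch, in_features).astype("float32"))

# Default path: weight stored as [out_features, in_features] and consumed
# with transpose_y=True (the positional False, True seen in the diff).
w_out_in = np.random.rand(out_features, in_features).astype("float32")
y_default = paddle.matmul(x, paddle.to_tensor(w_out_in), False, True)

# ROCm path: the same weight kept pre-transposed as [in_features, out_features],
# so a plain matmul gives the identical result without a runtime transpose.
w_in_out = np.ascontiguousarray(w_out_in.transpose(1, 0))
y_rocm = paddle.matmul(x, paddle.to_tensor(w_in_out))

assert paddle.allclose(y_default, y_rocm).item()

Storing the weight as [in_features, out_features] presumably lets the DCU GEMM operate on a plain, non-transposed operand, which is the performance angle the reviewer asked to have documented in the PR description.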
@@ -1421,10 +1432,16 @@ def compute_activation(self, ffn1_out, i):
)

def compute_ffn1(self, tmp_out, i):
return paddle.matmul(tmp_out, self.ffn1_weights[i], False, True)
if paddle.device.is_compiled_with_rocm():
return paddle.matmul(tmp_out, self.ffn1_weights[i])
else:
return paddle.matmul(tmp_out, self.ffn1_weights[i], False, True)

def compute_ffn2(self, ffn1_out, i):
ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True)
if paddle.device.is_compiled_with_rocm():
ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i])
else:
ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True)
ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
return ffn2_out

85 changes: 61 additions & 24 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -564,6 +564,7 @@ def __init__(self, config: LlamaConfig):
use_neox_rotary_style=True,
use_dynamic_cachekv_quant=config.use_cachekv_int8 == "dynamic",
rank_id=config.tensor_parallel_rank,
trans_qkvw=(True if not paddle.is_compiled_with_rocm() else False),
)

self.set_transformer_block(transformer_config)
@@ -751,25 +752,42 @@ def set_state_dict(self, state_dict):
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]
concated_qkv_weight = (
np.concatenate(
if paddle.is_compiled_with_rocm():
concated_qkv_weight = np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
).reshape(
self.hidden_size,
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
)
* (head_size),
self.hidden_size,
)
)
else:
concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
unfused_state_dict["self_attn.k_proj.weight"],
unfused_state_dict["self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
)
* (head_size),
self.hidden_size,
)
)
if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
concated_ffn1_weight = np.concatenate(
split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
@@ -816,14 +834,21 @@ def set_state_dict(self, state_dict):
self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight_tensor)
self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale_tensor)
elif self.quant_type == "a8w8":
self.transformer_block.linear_weights[idx].set_value(
paddle.cast(
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]).transpose(
(1, 0)
),
"int8",
if paddle.is_compiled_with_rocm():
self.transformer_block.linear_weights[idx].set_value(
paddle.cast(
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]), "int8"
)
)
else:
self.transformer_block.linear_weights[idx].set_value(
paddle.cast(
paddle.to_tensor(
state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]
).transpose((1, 0)),
"int8",
)
)
)
else:
self.transformer_block.linear_weights[idx].set_value(
linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
@@ -839,9 +864,14 @@ def set_state_dict(self, state_dict):
self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor)
self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor)
elif self.quant_type == "a8w8":
self.transformer_block.ffn1_weights[idx].set_value(
paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
)
if paddle.is_compiled_with_rocm():
self.transformer_block.ffn1_weights[idx].set_value(
paddle.cast(paddle.to_tensor(concated_ffn1_weight), "int8")
)
else:
self.transformer_block.ffn1_weights[idx].set_value(
paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
)
else:
self.transformer_block.ffn1_weights[idx].set_value(
ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
@@ -858,14 +888,21 @@ def set_state_dict(self, state_dict):
self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor)
self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor)
elif self.quant_type == "a8w8":
self.transformer_block.ffn2_weights[idx].set_value(
paddle.cast(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]).transpose(
(1, 0)
),
"int8",
if paddle.is_compiled_with_rocm():
self.transformer_block.ffn2_weights[idx].set_value(
paddle.cast(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]), "int8"
)
)
else:
self.transformer_block.ffn2_weights[idx].set_value(
paddle.cast(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]).transpose(
(1, 0)
),
"int8",
)
)
)
else:
self.transformer_block.ffn2_weights[idx].set_value(
ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
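For readers comparing the two set_state_dict branches above, here is a hedged NumPy-only sketch of the fused QKV weight packing (tensor parallelism, int8 quantization, and scales omitted; all sizes are illustrative). The ROCm branch keeps the concatenated weight in the checkpoint layout [hidden_size, fused_out], while the default branch transposes it to [fused_out, hidden_size] to match the transpose_y=True matmul convention used on other backends.

import numpy as np

# Illustrative sizes; the real values come from the Llama config.
hidden_size, num_heads, num_kv_heads, head_dim = 16, 4, 4, 4
fused_out = (num_heads + 2 * num_kv_heads) * head_dim

# Per-projection weights in the Paddle checkpoint layout assumed here: [in, out].
q = np.random.rand(hidden_size, num_heads * head_dim).astype("float32")
k = np.random.rand(hidden_size, num_kv_heads * head_dim).astype("float32")
v = np.random.rand(hidden_size, num_kv_heads * head_dim).astype("float32")

# ROCm branch: concatenate along the output axis and keep [hidden_size, fused_out].
qkv_rocm = np.concatenate([q, k, v], axis=-1).reshape(hidden_size, fused_out)

# Default branch: same concatenation, then transpose so the fused weight becomes
# [fused_out, hidden_size], matching paddle.matmul(..., transpose_y=True).
qkv_default = np.concatenate([q, k, v], axis=-1).transpose(1, 0).reshape(fused_out, hidden_size)

# Both layouts describe the same projection.
x = np.random.rand(2, hidden_size).astype("float32")
assert np.allclose(x @ qkv_rocm, x @ qkv_default.T)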