Skip to content

Commit e69e140

Browse files
committed
fix rope bug && weight bug
1 parent 8a16ee4 commit e69e140

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

csrc/gpu/get_position_ids.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
__global__ void GetPositionIdsKernel(
1919
const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度
2020
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
21+
const int* seq_lens_this_time,
2122
int* position_ids, // 输出的一维 position_ids
2223
const int bsz) { // 批次大小
2324
// 当前线程索引(每个线程对应一个批次)
@@ -29,13 +30,14 @@ __global__ void GetPositionIdsKernel(
2930
for (int i = 0; i < tid; i++) {
3031
offset += seq_lens_encoder[i];
3132
if (seq_lens_decoder[i] > 0) {
32-
offset += 1;
33+
offset += seq_lens_this_time[i];
3334
}
3435
}
3536

3637
// 当前批次的 encoder 和 decoder 长度
3738
int encoder_len = seq_lens_encoder[tid];
3839
int decoder_len = seq_lens_decoder[tid];
40+
int seq_len_this_time = seq_lens_this_time[tid];
3941

4042
// 写入 encoder 的 position_ids
4143
for (int i = 0; i < encoder_len; i++) {
@@ -45,25 +47,29 @@ __global__ void GetPositionIdsKernel(
4547

4648
// 写入 decoder 的 position_ids
4749
if (decoder_len > 0) {
48-
position_ids[offset] = decoder_len; // 使用 decoder 长度本身
50+
for (int i = 0; i < seq_len_this_time; i++) {
51+
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
52+
}
4953
}
5054
}
5155

5256

5357
void GetPositionIds(const paddle::Tensor& seq_lens_encoder,
5458
const paddle::Tensor& seq_lens_decoder,
59+
const paddle::Tensor& seq_lens_this_time,
5560
const paddle::Tensor& position_ids) {
5661
const int bsz = seq_lens_encoder.shape()[0];
5762

5863
GetPositionIdsKernel<<<1, bsz, 0, position_ids.stream()>>>(
5964
seq_lens_encoder.data<int>(),
6065
seq_lens_decoder.data<int>(),
66+
seq_lens_this_time.data<int>(),
6167
const_cast<int*>(position_ids.data<int>()),
6268
bsz);
6369
}
6470

6571
PD_BUILD_OP(get_position_ids)
66-
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "position_ids"})
72+
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", "position_ids"})
6773
.Outputs({"position_ids_out"})
6874
.SetInplaceMap({{"position_ids", "position_ids_out"}})
6975
.SetKernelFn(PD_KERNEL(GetPositionIds));

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,9 +1161,12 @@ def set_state_dict(self, state_dict):
11611161
self.lm_head.weight.set_value(
11621162
paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
11631163
)
1164-
# self.mtp.fc.weight.set_value(paddle.to_tensor(state_dict["llama.fc.weight"]).cast(self.lm_head.weight.dtype))
1165-
# self.mtp.fc.bias.set_value(paddle.to_tensor(state_dict["llama.fc.bias"]).cast(self.lm_head.weight.dtype))
1166-
1164+
1165+
self.mtp.enorm.weight.set_value(paddle.to_tensor(state_dict["deepseek_v3_mtp.enorm.weight"]).cast(self.lm_head.weight.dtype))
1166+
self.mtp.hnorm.weight.set_value(paddle.to_tensor(state_dict["deepseek_v3_mtp.hnorm.weight"]).cast(self.lm_head.weight.dtype))
1167+
self.mtp.norm.weight.set_value(paddle.to_tensor(state_dict["deepseek_v3_mtp.norm.weight"]).cast(self.lm_head.weight.dtype))
1168+
self.mtp.eh_proj.weight.set_value(paddle.to_tensor(state_dict["deepseek_v3_mtp.eh_proj.weight"]).cast(self.lm_head.weight.dtype))
1169+
11671170
self.mtp.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
11681171

11691172
def forward(

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,13 +1310,14 @@ def pre_process(self, **kwargs):
13101310
if self.config.mla_config.use_mla():
13111311
seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
13121312
seq_lens_decoder = kwargs.get("seq_lens_decoder", None)
1313-
position_ids_shape = paddle.sum(seq_lens_encoder) + paddle.sum(seq_lens_decoder > 0)
1313+
seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
1314+
position_ids_shape = paddle.sum(seq_lens_this_time)
13141315
self.position_ids = paddle.zeros(shape=position_ids_shape, dtype=seq_lens_encoder.dtype)
13151316

13161317
from paddlenlp_ops import get_position_ids
13171318

13181319
# In-place operations that compute the position_ids.
1319-
get_position_ids(seq_lens_encoder, seq_lens_decoder, self.position_ids)
1320+
get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, self.position_ids)
13201321

13211322
def post_process(self, **kwargs):
13221323
time_step = kwargs.get("time_step", None)

0 commit comments

Comments (0)