Skip to content

Commit 52eda7f

Browse files
authored
[Feature][MTP] Support a new speculative decoding method named hybrid MTP with ngram (PaddlePaddle#3610)
1 parent 0a0d295 commit 52eda7f

20 files changed

Lines changed: 454 additions & 571 deletions

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ void SpeculateVerify(
614614
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
615615
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
616616

617-
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
617+
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
618618
const paddle::Tensor &seq_lens_decoder,
619619
const paddle::Tensor &not_need_stop,
620620
const paddle::Tensor &draft_tokens,
@@ -659,6 +659,20 @@ void NgramMatch(const paddle::Tensor &input_ids,
659659
const int max_draft_tokens);
660660

661661

662+
void HybridMtpNgram(const paddle::Tensor &input_ids,
663+
const paddle::Tensor &input_ids_len,
664+
const paddle::Tensor &pre_ids,
665+
const paddle::Tensor &step_idx,
666+
const paddle::Tensor &draft_token_num,
667+
const paddle::Tensor &draft_tokens,
668+
const paddle::Tensor &seq_lens_this_time,
669+
const paddle::Tensor &seq_lens_decoder,
670+
const paddle::Tensor &max_dec_len,
671+
const int max_ngram_size,
672+
const int min_ngram_size,
673+
const int max_draft_tokens);
674+
675+
662676
// MTP
663677
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
664678
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -675,6 +689,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
675689
const paddle::Tensor& step_idx,
676690
const paddle::Tensor& not_need_stop,
677691
const paddle::Tensor& batch_drop,
692+
const paddle::Tensor& pre_ids,
678693
const paddle::Tensor& accept_tokens,
679694
const paddle::Tensor& accept_num,
680695
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -1121,7 +1136,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
11211136

11221137
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
11231138

1124-
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
1139+
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
11251140

11261141
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
11271142

@@ -1131,6 +1146,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
11311146

11321147
m.def("ngram_match", &NgramMatch, "ngram_match function");
11331148

1149+
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
1150+
11341151
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
11351152

11361153
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");

custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ __global__ void process_splitwise_prefill(
2626
int64_t* step_idx,
2727
bool* not_need_stop,
2828
bool* batch_drop,
29+
int64_t* pre_ids,
2930
const int64_t* accept_tokens,
3031
const int* accept_num,
3132
const int* base_model_seq_lens_this_time,
@@ -36,11 +37,12 @@ __global__ void process_splitwise_prefill(
3637
const bool* base_model_is_block_step,
3738
int64_t* base_model_draft_tokens,
3839
const int bsz,
39-
const int max_draft_token,
40+
const int num_model_step,
4041
const int accept_tokens_len,
4142
const int draft_tokens_len,
4243
const int input_ids_len,
43-
const int base_model_draft_tokens_len) {
44+
const int base_model_draft_tokens_len,
45+
const int pre_ids_len) {
4446
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
4547
__shared__ typename BlockReduce::TempStorage temp_storage;
4648
int64_t not_stop_flag = 0;
@@ -93,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
9395
int64_t* step_idx,
9496
bool* not_need_stop,
9597
bool* batch_drop,
98+
int64_t* pre_ids,
9699
const int64_t* accept_tokens,
97100
const int* accept_num,
98101
const int* base_model_seq_lens_this_time,
@@ -103,11 +106,12 @@ __global__ void draft_model_preprocess_kernel(
103106
const bool* base_model_is_block_step,
104107
int64_t* base_model_draft_tokens,
105108
const int bsz,
106-
const int max_draft_token,
109+
const int num_model_step,
107110
const int accept_tokens_len,
108111
const int draft_tokens_len,
109112
const int input_ids_len,
110-
const int base_model_draft_tokens_len) {
113+
const int base_model_draft_tokens_len,
114+
const int pre_ids_len) {
111115
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
112116
__shared__ typename BlockReduce::TempStorage temp_storage;
113117
int64_t not_stop_flag = 0;
@@ -124,6 +128,7 @@ __global__ void draft_model_preprocess_kernel(
124128
base_model_draft_tokens + tid * base_model_draft_tokens_len;
125129
auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
126130
const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
131+
auto* pre_ids_now = pre_ids + tid * pre_ids_len;
127132
#pragma unroll
128133
for (int i = 1; i < base_model_draft_tokens_len; i++) {
129134
base_model_draft_tokens_now[i] = -1;
@@ -137,14 +142,12 @@ __global__ void draft_model_preprocess_kernel(
137142
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
138143
not_stop_flag = 1;
139144
// 1. first token
140-
if (base_model_step_idx_now == 0) {
141-
seq_lens_this_time[tid] = 0;
142-
not_stop_flag = 0;
143-
} else if (seq_lens_encoder[tid] > 0) {
145+
if (seq_lens_encoder[tid] > 0) {
144146
// Can be extended to first few tokens
145147
int seq_len_encoder = seq_lens_encoder[tid];
146148
stop_flags[tid] = false;
147149
int64_t base_model_first_token = accept_tokens_now[0];
150+
pre_ids_now[0] = base_model_first_token;
148151
int position = seq_len_encoder;
149152
if (TRCUNCATE_FIRST_TOKEN) {
150153
input_ids_now[position - 1] = base_model_first_token;
@@ -161,34 +164,17 @@ __global__ void draft_model_preprocess_kernel(
161164
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
162165
} else {
163166
// 2: Last base model generated token and first MTP token
164-
seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2);
165-
step_idx[tid] -= (base_model_seq_len_this_time - 2);
167+
seq_lens_decoder[tid] -= num_model_step - 1;
168+
step_idx[tid] -= num_model_step - 1;
166169
}
167170
for (int i = 0; i < accept_num_now; i++) {
168171
draft_tokens_now[i] = accept_tokens_now[i];
172+
const int pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i);
173+
const int64_t accept_token = accept_tokens_now[i];
174+
pre_ids_now[pre_id_pos] = accept_token;
169175
}
170176
seq_lens_this_time[tid] = accept_num_now;
171177
}
172-
// (liuzichang): Temperary Reserved for debug
173-
// else if (accept_num_now <=
174-
// max_draft_token) /*Accept partial draft tokens*/ {
175-
// // Base Model reject stop
176-
// if (stop_flags[tid]) {
177-
// stop_flags[tid] = false;
178-
// seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
179-
// step_idx[tid] = base_model_step_idx[tid];
180-
// } else {
181-
// seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
182-
// step_idx[tid] -= max_draft_token - accept_num_now;
183-
// }
184-
// int64_t modified_token = accept_tokens_now[accept_num_now - 1];
185-
// draft_tokens_now[0] = modified_token;
186-
// seq_lens_this_time[tid] = 1;
187-
188-
// } else /*Accept all draft tokens*/ {
189-
// draft_tokens_now[1] = accept_tokens_now[max_draft_token];
190-
// seq_lens_this_time[tid] = 2;
191-
// }
192178
} else {
193179
stop_flags[tid] = true;
194180
seq_lens_this_time[tid] = 0;
@@ -215,6 +201,7 @@ void DispatchRunner(
215201
int64_t* step_idx,
216202
bool* not_need_stop,
217203
bool* batch_drop,
204+
int64_t* pre_ids,
218205
const int64_t* accept_tokens,
219206
const int* accept_num,
220207
const int* base_model_seq_lens_this_time,
@@ -225,11 +212,12 @@ void DispatchRunner(
225212
const bool* base_model_is_block_step,
226213
int64_t* base_model_draft_tokens,
227214
const int bsz,
228-
const int max_draft_token,
215+
const int num_model_step,
229216
const int accept_tokens_len,
230217
const int draft_tokens_len,
231218
const int input_ids_len,
232219
const int base_model_draft_tokens_len,
220+
const int pre_ids_len,
233221
const bool splitwise_prefill) {
234222
constexpr int BlockSize = 512;
235223
if (splitwise_prefill) {
@@ -244,6 +232,7 @@ void DispatchRunner(
244232
step_idx,
245233
not_need_stop,
246234
batch_drop,
235+
pre_ids,
247236
accept_tokens,
248237
accept_num,
249238
base_model_seq_lens_this_time,
@@ -254,11 +243,12 @@ void DispatchRunner(
254243
base_model_is_block_step,
255244
base_model_draft_tokens,
256245
bsz,
257-
max_draft_token,
246+
num_model_step,
258247
accept_tokens_len,
259248
draft_tokens_len,
260249
input_ids_len,
261-
base_model_draft_tokens_len);
250+
base_model_draft_tokens_len,
251+
pre_ids_len);
262252
} else {
263253
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
264254
<<<1, BlockSize, 0, stream>>>(
@@ -271,6 +261,7 @@ void DispatchRunner(
271261
step_idx,
272262
not_need_stop,
273263
batch_drop,
264+
pre_ids,
274265
accept_tokens,
275266
accept_num,
276267
base_model_seq_lens_this_time,
@@ -281,11 +272,12 @@ void DispatchRunner(
281272
base_model_is_block_step,
282273
base_model_draft_tokens,
283274
bsz,
284-
max_draft_token,
275+
num_model_step,
285276
accept_tokens_len,
286277
draft_tokens_len,
287278
input_ids_len,
288-
base_model_draft_tokens_len);
279+
base_model_draft_tokens_len,
280+
pre_ids_len);
289281
}
290282
}
291283

@@ -300,6 +292,7 @@ void DispatchTokenMode(
300292
int64_t* step_idx,
301293
bool* not_need_stop,
302294
bool* batch_drop,
295+
int64_t* pre_ids,
303296
const int64_t* accept_tokens,
304297
const int* accept_num,
305298
const int* base_model_seq_lens_this_time,
@@ -310,11 +303,12 @@ void DispatchTokenMode(
310303
const bool* base_model_is_block_step,
311304
int64_t* base_model_draft_tokens,
312305
const int bsz,
313-
const int max_draft_token,
306+
const int num_model_step,
314307
const int accept_tokens_len,
315308
const int draft_tokens_len,
316309
const int input_ids_len,
317310
const int base_model_draft_tokens_len,
311+
const int pre_ids_len,
318312
const bool truncate_first_token,
319313
const bool splitwise_prefill) {
320314
if (truncate_first_token) {
@@ -329,6 +323,7 @@ void DispatchTokenMode(
329323
step_idx,
330324
not_need_stop,
331325
batch_drop,
326+
pre_ids,
332327
accept_tokens,
333328
accept_num,
334329
base_model_seq_lens_this_time,
@@ -339,11 +334,12 @@ void DispatchTokenMode(
339334
base_model_is_block_step,
340335
base_model_draft_tokens,
341336
bsz,
342-
max_draft_token,
337+
num_model_step,
343338
accept_tokens_len,
344339
draft_tokens_len,
345340
input_ids_len,
346341
base_model_draft_tokens_len,
342+
pre_ids_len,
347343
splitwise_prefill
348344
);
349345
} else {
@@ -358,6 +354,7 @@ void DispatchTokenMode(
358354
step_idx,
359355
not_need_stop,
360356
batch_drop,
357+
pre_ids,
361358
accept_tokens,
362359
accept_num,
363360
base_model_seq_lens_this_time,
@@ -368,11 +365,12 @@ void DispatchTokenMode(
368365
base_model_is_block_step,
369366
base_model_draft_tokens,
370367
bsz,
371-
max_draft_token,
368+
num_model_step,
372369
accept_tokens_len,
373370
draft_tokens_len,
374371
input_ids_len,
375372
base_model_draft_tokens_len,
373+
pre_ids_len,
376374
splitwise_prefill
377375
);
378376
}
@@ -390,6 +388,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
390388
const paddle::Tensor& step_idx,
391389
const paddle::Tensor& not_need_stop,
392390
const paddle::Tensor& batch_drop,
391+
const paddle::Tensor& pre_ids,
393392
const paddle::Tensor& accept_tokens,
394393
const paddle::Tensor& accept_num,
395394
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -399,13 +398,14 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
399398
const paddle::Tensor& base_model_stop_flags,
400399
const paddle::Tensor& base_model_is_block_step,
401400
const paddle::Tensor& base_model_draft_tokens,
402-
const int max_draft_token,
401+
const int num_model_step,
403402
const bool truncate_first_token,
404403
const bool splitwise_prefill) {
405404
int real_bsz = seq_lens_this_time.shape()[0];
406405
int accept_tokens_len = accept_tokens.shape()[1];
407406
int input_ids_len = input_ids.shape()[1];
408407
int draft_tokens_len = draft_tokens.shape()[1];
408+
int pre_ids_len = pre_ids.shape()[1];
409409
auto cu_stream = seq_lens_this_time.stream();
410410
constexpr int BlockSize = 512;
411411
int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1];
@@ -423,6 +423,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
423423
const_cast<int64_t*>(step_idx.data<int64_t>()),
424424
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
425425
const_cast<bool*>(batch_drop.data<bool>()),
426+
const_cast<int64_t*>(pre_ids.data<int64_t>()),
426427
accept_tokens.data<int64_t>(),
427428
accept_num.data<int>(),
428429
base_model_seq_lens_this_time.data<int>(),
@@ -433,11 +434,12 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
433434
base_model_is_block_step.data<bool>(),
434435
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
435436
real_bsz,
436-
max_draft_token,
437+
num_model_step,
437438
accept_tokens_len,
438439
draft_tokens_len,
439440
input_ids_len,
440441
base_model_draft_tokens_len,
442+
pre_ids_len,
441443
truncate_first_token,
442444
splitwise_prefill);
443445

@@ -458,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
458460
"step_idx",
459461
"not_need_stop",
460462
"batch_drop",
463+
"pre_ids",
461464
"accept_tokens",
462465
"accept_num",
463466
"base_model_seq_lens_this_time",
@@ -475,8 +478,9 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
475478
"seq_lens_decoder_out",
476479
"step_idx_out",
477480
"not_need_stop_out",
478-
"batch_drop_out"})
479-
.Attrs({"max_draft_token: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
481+
"batch_drop_out",
482+
"pre_ids_out"})
483+
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
480484
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
481485
{"input_ids", "input_ids_out"},
482486
{"stop_flags", "stop_flags_out"},
@@ -485,5 +489,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
485489
{"seq_lens_decoder", "seq_lens_decoder_out"},
486490
{"step_idx", "step_idx_out"},
487491
{"not_need_stop", "not_need_stop_out"},
488-
{"batch_drop", "batch_drop_out"}})
492+
{"batch_drop", "batch_drop_out"},
493+
{"pre_ids", "pre_ids_out"}})
489494
.SetKernelFn(PD_KERNEL(DraftModelPreprocess));

custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,9 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
6363
token_this_time = next_tokens_start[seq_len_this_time - 1];
6464
draft_token_now[0] = next_tokens_start[seq_len_this_time - 1];
6565
base_model_draft_tokens_now[substep + 1] = token_this_time;
66-
for (int i = 0; i < seq_len_this_time; ++i) {
67-
pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i];
68-
}
6966
step_idx[tid] += seq_len_this_time;
67+
pre_ids_now[step_idx[tid]] = token_this_time;
68+
7069

7170
} else {
7271
token_this_time = next_tokens_start[0];

custom_ops/gpu_ops/speculate_decoding/draft_model/eagle_get_base_model_hidden_states.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ __global__ void ComputeOrderKernel(
4949
for (int j = 0; j < cur_seq_lens_encoder; j++) {
5050
position_map[in_offset++] = out_offset++;
5151
}
52-
// 2. base model encoder. Base step=0
53-
} else if (cur_base_model_seq_lens_encoder != 0) {
54-
// 3. New end
52+
// 2. Base model stop at last verify-step.
5553
} else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) {
5654
#ifdef DEBUG_EAGLE_KERNEL
5755
printf("batch %d: base=0. draft !=0 \n", i);

0 commit comments

Comments (0)