@@ -26,6 +26,7 @@ __global__ void process_splitwise_prefill(
2626 int64_t * step_idx,
2727 bool * not_need_stop,
2828 bool * batch_drop,
29+ int64_t * pre_ids,
2930 const int64_t * accept_tokens,
3031 const int * accept_num,
3132 const int * base_model_seq_lens_this_time,
@@ -36,11 +37,12 @@ __global__ void process_splitwise_prefill(
3637 const bool * base_model_is_block_step,
3738 int64_t * base_model_draft_tokens,
3839 const int bsz,
39- const int max_draft_token ,
40+ const int num_model_step ,
4041 const int accept_tokens_len,
4142 const int draft_tokens_len,
4243 const int input_ids_len,
43- const int base_model_draft_tokens_len) {
44+ const int base_model_draft_tokens_len,
45+ const int pre_ids_len) {
4446 typedef cub::BlockReduce<int64_t , THREADBLOCK_SIZE> BlockReduce;
4547 __shared__ typename BlockReduce::TempStorage temp_storage;
4648 int64_t not_stop_flag = 0 ;
@@ -93,6 +95,7 @@ __global__ void draft_model_preprocess_kernel(
9395 int64_t * step_idx,
9496 bool * not_need_stop,
9597 bool * batch_drop,
98+ int64_t * pre_ids,
9699 const int64_t * accept_tokens,
97100 const int * accept_num,
98101 const int * base_model_seq_lens_this_time,
@@ -103,11 +106,12 @@ __global__ void draft_model_preprocess_kernel(
103106 const bool * base_model_is_block_step,
104107 int64_t * base_model_draft_tokens,
105108 const int bsz,
106- const int max_draft_token ,
109+ const int num_model_step ,
107110 const int accept_tokens_len,
108111 const int draft_tokens_len,
109112 const int input_ids_len,
110- const int base_model_draft_tokens_len) {
113+ const int base_model_draft_tokens_len,
114+ const int pre_ids_len) {
111115 typedef cub::BlockReduce<int64_t , THREADBLOCK_SIZE> BlockReduce;
112116 __shared__ typename BlockReduce::TempStorage temp_storage;
113117 int64_t not_stop_flag = 0 ;
@@ -124,6 +128,7 @@ __global__ void draft_model_preprocess_kernel(
124128 base_model_draft_tokens + tid * base_model_draft_tokens_len;
125129 auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid];
126130 const int32_t base_model_seq_len_this_time = base_model_seq_lens_this_time[tid];
131+ auto * pre_ids_now = pre_ids + tid * pre_ids_len;
127132#pragma unroll
128133 for (int i = 1 ; i < base_model_draft_tokens_len; i++) {
129134 base_model_draft_tokens_now[i] = -1 ;
@@ -137,14 +142,12 @@ __global__ void draft_model_preprocess_kernel(
137142 if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
138143 not_stop_flag = 1 ;
139144 // 1. first token
140- if (base_model_step_idx_now == 0 ) {
141- seq_lens_this_time[tid] = 0 ;
142- not_stop_flag = 0 ;
143- } else if (seq_lens_encoder[tid] > 0 ) {
145+ if (seq_lens_encoder[tid] > 0 ) {
144146 // Can be extended to first few tokens
145147 int seq_len_encoder = seq_lens_encoder[tid];
146148 stop_flags[tid] = false ;
147149 int64_t base_model_first_token = accept_tokens_now[0 ];
150+ pre_ids_now[0 ] = base_model_first_token;
148151 int position = seq_len_encoder;
149152 if (TRCUNCATE_FIRST_TOKEN) {
150153 input_ids_now[position - 1 ] = base_model_first_token;
@@ -161,34 +164,17 @@ __global__ void draft_model_preprocess_kernel(
161164 step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time;
162165 } else {
163166 // 2: Last base model generated token and first MTP token
164- seq_lens_decoder[tid] -= (base_model_seq_len_this_time - 2 ) ;
165- step_idx[tid] -= (base_model_seq_len_this_time - 2 ) ;
167+ seq_lens_decoder[tid] -= num_model_step - 1 ;
168+ step_idx[tid] -= num_model_step - 1 ;
166169 }
167170 for (int i = 0 ; i < accept_num_now; i++) {
168171 draft_tokens_now[i] = accept_tokens_now[i];
172+ const int pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i);
173+ const int64_t accept_token = accept_tokens_now[i];
174+ pre_ids_now[pre_id_pos] = accept_token;
169175 }
170176 seq_lens_this_time[tid] = accept_num_now;
171177 }
172- // (liuzichang): Temperary Reserved for debug
173- // else if (accept_num_now <=
174- // max_draft_token) /*Accept partial draft tokens*/ {
175- // // Base Model reject stop
176- // if (stop_flags[tid]) {
177- // stop_flags[tid] = false;
178- // seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid];
179- // step_idx[tid] = base_model_step_idx[tid];
180- // } else {
181- // seq_lens_decoder[tid] -= max_draft_token - accept_num_now;
182- // step_idx[tid] -= max_draft_token - accept_num_now;
183- // }
184- // int64_t modified_token = accept_tokens_now[accept_num_now - 1];
185- // draft_tokens_now[0] = modified_token;
186- // seq_lens_this_time[tid] = 1;
187-
188- // } else /*Accept all draft tokens*/ {
189- // draft_tokens_now[1] = accept_tokens_now[max_draft_token];
190- // seq_lens_this_time[tid] = 2;
191- // }
192178 } else {
193179 stop_flags[tid] = true ;
194180 seq_lens_this_time[tid] = 0 ;
@@ -215,6 +201,7 @@ void DispatchRunner(
215201 int64_t * step_idx,
216202 bool * not_need_stop,
217203 bool * batch_drop,
204+ int64_t * pre_ids,
218205 const int64_t * accept_tokens,
219206 const int * accept_num,
220207 const int * base_model_seq_lens_this_time,
@@ -225,11 +212,12 @@ void DispatchRunner(
225212 const bool * base_model_is_block_step,
226213 int64_t * base_model_draft_tokens,
227214 const int bsz,
228- const int max_draft_token ,
215+ const int num_model_step ,
229216 const int accept_tokens_len,
230217 const int draft_tokens_len,
231218 const int input_ids_len,
232219 const int base_model_draft_tokens_len,
220+ const int pre_ids_len,
233221 const bool splitwise_prefill) {
234222 constexpr int BlockSize = 512 ;
235223 if (splitwise_prefill) {
@@ -244,6 +232,7 @@ void DispatchRunner(
244232 step_idx,
245233 not_need_stop,
246234 batch_drop,
235+ pre_ids,
247236 accept_tokens,
248237 accept_num,
249238 base_model_seq_lens_this_time,
@@ -254,11 +243,12 @@ void DispatchRunner(
254243 base_model_is_block_step,
255244 base_model_draft_tokens,
256245 bsz,
257- max_draft_token ,
246+ num_model_step ,
258247 accept_tokens_len,
259248 draft_tokens_len,
260249 input_ids_len,
261- base_model_draft_tokens_len);
250+ base_model_draft_tokens_len,
251+ pre_ids_len);
262252 } else {
263253 draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
264254 <<<1 , BlockSize, 0 , stream>>> (
@@ -271,6 +261,7 @@ void DispatchRunner(
271261 step_idx,
272262 not_need_stop,
273263 batch_drop,
264+ pre_ids,
274265 accept_tokens,
275266 accept_num,
276267 base_model_seq_lens_this_time,
@@ -281,11 +272,12 @@ void DispatchRunner(
281272 base_model_is_block_step,
282273 base_model_draft_tokens,
283274 bsz,
284- max_draft_token ,
275+ num_model_step ,
285276 accept_tokens_len,
286277 draft_tokens_len,
287278 input_ids_len,
288- base_model_draft_tokens_len);
279+ base_model_draft_tokens_len,
280+ pre_ids_len);
289281 }
290282}
291283
@@ -300,6 +292,7 @@ void DispatchTokenMode(
300292 int64_t * step_idx,
301293 bool * not_need_stop,
302294 bool * batch_drop,
295+ int64_t * pre_ids,
303296 const int64_t * accept_tokens,
304297 const int * accept_num,
305298 const int * base_model_seq_lens_this_time,
@@ -310,11 +303,12 @@ void DispatchTokenMode(
310303 const bool * base_model_is_block_step,
311304 int64_t * base_model_draft_tokens,
312305 const int bsz,
313- const int max_draft_token ,
306+ const int num_model_step ,
314307 const int accept_tokens_len,
315308 const int draft_tokens_len,
316309 const int input_ids_len,
317310 const int base_model_draft_tokens_len,
311+ const int pre_ids_len,
318312 const bool truncate_first_token,
319313 const bool splitwise_prefill) {
320314 if (truncate_first_token) {
@@ -329,6 +323,7 @@ void DispatchTokenMode(
329323 step_idx,
330324 not_need_stop,
331325 batch_drop,
326+ pre_ids,
332327 accept_tokens,
333328 accept_num,
334329 base_model_seq_lens_this_time,
@@ -339,11 +334,12 @@ void DispatchTokenMode(
339334 base_model_is_block_step,
340335 base_model_draft_tokens,
341336 bsz,
342- max_draft_token ,
337+ num_model_step ,
343338 accept_tokens_len,
344339 draft_tokens_len,
345340 input_ids_len,
346341 base_model_draft_tokens_len,
342+ pre_ids_len,
347343 splitwise_prefill
348344 );
349345 } else {
@@ -358,6 +354,7 @@ void DispatchTokenMode(
358354 step_idx,
359355 not_need_stop,
360356 batch_drop,
357+ pre_ids,
361358 accept_tokens,
362359 accept_num,
363360 base_model_seq_lens_this_time,
@@ -368,11 +365,12 @@ void DispatchTokenMode(
368365 base_model_is_block_step,
369366 base_model_draft_tokens,
370367 bsz,
371- max_draft_token ,
368+ num_model_step ,
372369 accept_tokens_len,
373370 draft_tokens_len,
374371 input_ids_len,
375372 base_model_draft_tokens_len,
373+ pre_ids_len,
376374 splitwise_prefill
377375 );
378376 }
@@ -390,6 +388,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
390388 const paddle::Tensor& step_idx,
391389 const paddle::Tensor& not_need_stop,
392390 const paddle::Tensor& batch_drop,
391+ const paddle::Tensor& pre_ids,
393392 const paddle::Tensor& accept_tokens,
394393 const paddle::Tensor& accept_num,
395394 const paddle::Tensor& base_model_seq_lens_this_time,
@@ -399,13 +398,14 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
399398 const paddle::Tensor& base_model_stop_flags,
400399 const paddle::Tensor& base_model_is_block_step,
401400 const paddle::Tensor& base_model_draft_tokens,
402- const int max_draft_token ,
401+ const int num_model_step ,
403402 const bool truncate_first_token,
404403 const bool splitwise_prefill) {
405404 int real_bsz = seq_lens_this_time.shape ()[0 ];
406405 int accept_tokens_len = accept_tokens.shape ()[1 ];
407406 int input_ids_len = input_ids.shape ()[1 ];
408407 int draft_tokens_len = draft_tokens.shape ()[1 ];
408+ int pre_ids_len = pre_ids.shape ()[1 ];
409409 auto cu_stream = seq_lens_this_time.stream ();
410410 constexpr int BlockSize = 512 ;
411411 int base_model_draft_tokens_len = base_model_draft_tokens.shape ()[1 ];
@@ -423,6 +423,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
423423 const_cast <int64_t *>(step_idx.data <int64_t >()),
424424 const_cast <bool *>(not_need_stop_gpu.data <bool >()),
425425 const_cast <bool *>(batch_drop.data <bool >()),
426+ const_cast <int64_t *>(pre_ids.data <int64_t >()),
426427 accept_tokens.data <int64_t >(),
427428 accept_num.data <int >(),
428429 base_model_seq_lens_this_time.data <int >(),
@@ -433,11 +434,12 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
433434 base_model_is_block_step.data <bool >(),
434435 const_cast <int64_t *>(base_model_draft_tokens.data <int64_t >()),
435436 real_bsz,
436- max_draft_token ,
437+ num_model_step ,
437438 accept_tokens_len,
438439 draft_tokens_len,
439440 input_ids_len,
440441 base_model_draft_tokens_len,
442+ pre_ids_len,
441443 truncate_first_token,
442444 splitwise_prefill);
443445
@@ -458,6 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
458460 " step_idx" ,
459461 " not_need_stop" ,
460462 " batch_drop" ,
463+ " pre_ids" ,
461464 " accept_tokens" ,
462465 " accept_num" ,
463466 " base_model_seq_lens_this_time" ,
@@ -475,8 +478,9 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
475478 " seq_lens_decoder_out" ,
476479 " step_idx_out" ,
477480 " not_need_stop_out" ,
478- " batch_drop_out" })
479- .Attrs({" max_draft_token: int" , " truncate_first_token: bool" , " splitwise_prefill: bool" })
481+ " batch_drop_out" ,
482+ " pre_ids_out" })
483+ .Attrs({" num_model_step: int" , " truncate_first_token: bool" , " splitwise_prefill: bool" })
480484 .SetInplaceMap({{" draft_tokens" , " draft_tokens_out" },
481485 {" input_ids" , " input_ids_out" },
482486 {" stop_flags" , " stop_flags_out" },
@@ -485,5 +489,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
485489 {" seq_lens_decoder" , " seq_lens_decoder_out" },
486490 {" step_idx" , " step_idx_out" },
487491 {" not_need_stop" , " not_need_stop_out" },
488- {" batch_drop" , " batch_drop_out" }})
492+ {" batch_drop" , " batch_drop_out" },
493+ {" pre_ids" , " pre_ids_out" }})
489494 .SetKernelFn(PD_KERNEL(DraftModelPreprocess));
0 commit comments