Commit 74aa31d

[Feature] support bad_words (PaddlePaddle#3055)
* support bad_words
* support online infer bad_words
* update
* add CI test
* update
* update
* update

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
1 parent 9c96234 commit 74aa31d

10 files changed

Lines changed: 263 additions & 15 deletions


custom_ops/gpu_ops/speculate_decoding/speculate_token_penalty_multi_scores.cu

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ void token_penalty_multi_scores_kernel(
     int64_t token_num = shape[0];
     int64_t length = shape[1];
     int64_t length_id = pre_ids.shape()[1];
-    int64_t length_bad_words = bad_tokens.shape()[0];
+    int64_t length_bad_words = bad_tokens.shape()[1];

     int64_t end_length = eos_token_id.shape()[0];

custom_ops/gpu_ops/token_penalty_multi_scores.cu

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,

     int64_t vocab_size = shape[1];
     int64_t max_dec_len = pre_ids.shape()[1];
-    int64_t bad_words_len = bad_tokens.shape()[0];
+    int64_t bad_words_len = bad_tokens.shape()[1];
     int64_t eos_len = eos_token_id.shape()[0];
     int64_t max_model_len = prompt_ids.shape()[1];
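
Both kernel fixes follow from the new buffer layout: `bad_tokens` used to be a flat 1-D list shared by the whole batch, so its length was read from `shape()[0]`; it is now allocated as one padded row per request, so the row width lives in `shape()[1]`. A minimal NumPy sketch of the layout change (sizes and names are illustrative, not from the commit):

```python
import numpy as np

# Assumed sizes for illustration; real values come from the model config.
max_num_seqs, vocab_size = 4, 32000

# Old layout: one flat list for the whole batch; length read from shape[0].
old_bad_tokens = np.array([101, 2024], dtype=np.int64)
assert old_bad_tokens.shape[0] == 2

# New layout: a padded [max_num_seqs, vocab_size] buffer, one row per
# request, filled with -1 so unused slots are easy to skip.
bad_tokens = np.full((max_num_seqs, vocab_size), -1, dtype=np.int64)
bad_tokens[0, :2] = [101, 2024]  # request 0 bans two token ids
bad_tokens[2, :1] = [555]        # request 2 bans one token id
assert bad_tokens.shape[1] == vocab_size  # what the kernels now read
```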

fastdeploy/engine/engine.py

Lines changed: 3 additions & 0 deletions
@@ -491,6 +491,7 @@ def add_requests(self, task, sampling_params=None, **kwargs):
         request = Request.from_dict(task)
         llm_logger.info(f"Receive request {request}")
         if sampling_params is not None:
+            sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
             request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()

@@ -747,6 +748,8 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
         """
         for task in tasks:
             start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if task.sampling_params.bad_words is not None:
+                task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
             # TODO: return to the scheduler
             if allocated:
                 current_tasks = []
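
Requests entering through `add_requests` tokenize their `bad_words` immediately; `insert_tasks` repeats the call for tasks arriving from the scheduler with `bad_words` still untokenized, so the per-step sampler only ever sees token ids. A hedged sketch of what the ingestion step does (the stub below is a stand-in for `self.data_processor.tokenizer`, for illustration only):

```python
from fastdeploy.engine.sampling_params import SamplingParams

class StubTokenizer:
    """Stand-in tokenizer; assumes every bad word maps to one id."""
    vocab_size = 32000

    def encode(self, text, add_special_tokens=False):
        return {"input_ids": [hash(text) % self.vocab_size]}

params = SamplingParams(bad_words=["badword"])
params.update_from_tokenizer(StubTokenizer())  # what add_requests now does
print(params.bad_words_token_ids)  # flat list of single-token ids
```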

fastdeploy/engine/sampling_params.py

Lines changed: 39 additions & 5 deletions
@@ -20,6 +20,8 @@
 from dataclasses import dataclass, fields
 from typing import Any, List, Optional, Union

+from fastdeploy.utils import llm_logger as logger
+

 @dataclass
 class SamplingParams:

@@ -97,6 +99,7 @@ class SamplingParams:
     min_tokens: int = 1
     logprobs: Optional[int] = None
     bad_words: Optional[List[str]] = None
+    _bad_words_token_ids: Optional[List[int]] = None

     @classmethod
     def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:

@@ -201,11 +204,42 @@ def _verify_args(self) -> None:
             raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

     def update_from_tokenizer(self, tokenizer):
-        """
-        # TODO: Implement stop tokens and bad words support
-        # Currently stop tokens and bad words are not supported yet
-        """
-        pass
+        """Tokenize `bad_words` into single-token ids usable by the sampler."""
+        if self.bad_words is None:
+            return
+        self._bad_words_token_ids = []
+        for bad_word in self.bad_words:
+            # To prohibit words both at the beginning and in the middle of
+            # the text (related to the add_prefix_space tokenizer parameter).
+            for add_prefix_space in [False, True]:
+                prefix = " " if add_prefix_space else ""
+                prompt = prefix + bad_word.lstrip()
+                prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)["input_ids"]
+
+                if len(prompt_token_ids) != 1:
+                    logger.warning(
+                        f"Skip bad_words: {prompt}. "
+                        f"Bad words should be a single token. "
+                        f"Got tokens: {prompt_token_ids}."
+                    )
+                    continue
+
+                if prompt_token_ids[0] >= tokenizer.vocab_size:
+                    logger.warning(
+                        f"Skip bad_words: {prompt}. "
+                        f"All token ids should satisfy "
+                        f"0 <= token_id < {tokenizer.vocab_size}. "
+                        f"Got token: {prompt_token_ids}."
+                    )
+                    continue
+
+                if prompt_token_ids[0] not in self._bad_words_token_ids:
+                    self._bad_words_token_ids.extend(prompt_token_ids)
+
+    @property
+    def bad_words_token_ids(self) -> Optional[List[int]]:
+        return self._bad_words_token_ids


 @dataclass
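
The `add_prefix_space` loop exists because BPE-style tokenizers usually give a word a different id when it follows a space than when it starts the text. A short illustration using a Hugging Face tokenizer (an assumed dependency, not part of this commit):

```python
# Illustrative only: assumes `transformers` is installed and "gpt2" resolves.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.encode("hello", add_special_tokens=False))   # id at start of text
print(tok.encode(" hello", add_special_tokens=False))  # different id mid-text
# update_from_tokenizer() encodes both variants so the word is banned
# wherever it appears; multi-token encodings are skipped with a warning.
```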

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 2 additions & 0 deletions
@@ -349,6 +349,7 @@ class CompletionRequest(BaseModel):
     extra_body: Optional[dict] = None
     return_token_ids: Optional[bool] = False
     prompt_token_ids: Optional[List[int]] = None
+    bad_words: Optional[List[str]] = None

     response_format: Optional[AnyResponseFormat] = None
     guided_json: Optional[Union[str, dict, BaseModel]] = None

@@ -484,6 +485,7 @@ class ChatCompletionRequest(BaseModel):
     return_token_ids: Optional[bool] = False
     prompt_token_ids: Optional[List[int]] = None
     disable_chat_template: Optional[bool] = False
+    bad_words: Optional[List[str]] = None

     response_format: Optional[AnyResponseFormat] = None
     guided_json: Optional[Union[str, dict, BaseModel]] = None
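
With `bad_words` exposed on both request models, clients can pass the list straight through the OpenAI-compatible endpoints. A hedged example against a local deployment (endpoint URL and model name are placeholders, not from this commit):

```python
import requests

# Placeholder endpoint and model name; adjust to your deployment.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Tell me a story."}],
        "bad_words": ["badword"],  # each entry should tokenize to one token
    },
)
print(resp.json())
```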

fastdeploy/worker/gcu_model_runner.py

Lines changed: 14 additions & 2 deletions
@@ -270,6 +270,14 @@ def get_attr_from_request(request, attr, default_value=None):
                 request.block_tables, dtype="int32"
             )

+            if request.get("bad_words_token_ids") is not None:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                if bad_words_len > 0:
+                    self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                    self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                        request.get("bad_words_token_ids"), dtype="int64"
+                    )
+
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):

@@ -382,7 +390,8 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")

-        self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
+        self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
+        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")

@@ -511,6 +520,9 @@ def _prepare_inputs(self) -> None:
         self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
         self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)

+        # Update bad tokens len
+        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
+
         # Initialize forward meta data
         self.initialize_forward_meta()

@@ -528,7 +540,7 @@ def _prepare_inputs(self) -> None:
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
-            bad_words_token_ids=self.share_inputs["bad_tokens"],
+            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
             max_num_logprobs=20 if self.enable_logprob else None,
         )
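
Slicing the padded buffer to the batch-wide maximum keeps the penalty kernel from scanning all `vocab_size` columns when only a few are populated; the GPU, iluvatar, and XPU runners below apply the identical pattern. A NumPy sketch of the trick (illustrative names, not the runner's API):

```python
import numpy as np

max_num_seqs, vocab_size = 3, 32000  # assumed sizes
bad_tokens = np.full((max_num_seqs, vocab_size), -1, dtype=np.int64)
bad_tokens_len = np.ones(max_num_seqs, dtype=np.int64)

# During prefill insertion, request 1 banned three token ids.
ids = np.array([7, 42, 9000], dtype=np.int64)
bad_tokens[1, : len(ids)] = ids
bad_tokens_len[1] = len(ids)

# _prepare_inputs slices to the batch maximum instead of the full width.
max_len = int(bad_tokens_len.max())
view = bad_tokens[:, :max_len]  # shape (3, 3) instead of (3, 32000)
print(view)
```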

fastdeploy/worker/gpu_model_runner.py

Lines changed: 14 additions & 2 deletions
@@ -448,6 +448,14 @@ def get_attr_from_request(request, attr, default_value=None):
                 request.block_tables, dtype="int32"
             )

+            if request.get("bad_words_token_ids") is not None:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                if bad_words_len > 0:
+                    self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                    self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                        request.get("bad_words_token_ids"), dtype="int64"
+                    )
+
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):

@@ -567,7 +575,8 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")

-        self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
+        self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
+        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")

@@ -733,6 +742,9 @@ def _prepare_inputs(self) -> None:
         self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
         self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)

+        # Update bad tokens len
+        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
+
         # Initialize forward meta data
         self.initialize_forward_meta()

@@ -750,7 +762,7 @@ def _prepare_inputs(self) -> None:
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
-            bad_words_token_ids=self.share_inputs["bad_tokens"],
+            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
             max_num_logprobs=20 if self.enable_logprob else None,
             enable_early_stop=self.enable_early_stop,

fastdeploy/worker/iluvatar_model_runner.py

Lines changed: 14 additions & 2 deletions
@@ -242,6 +242,14 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 request.block_tables, dtype="int32"
             )

+            if request.get("bad_words_token_ids") is not None:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                if bad_words_len > 0:
+                    self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                    self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                        request.get("bad_words_token_ids"), dtype="int64"
+                    )
+
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):

@@ -347,7 +355,8 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")

-        self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
+        self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
+        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")

@@ -484,6 +493,9 @@ def _prepare_inputs(self) -> None:
         self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
         self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False)

+        # Update bad tokens len
+        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
+
         # Initialize forward meta data
         self.initialize_forward_meta()

@@ -500,7 +512,7 @@ def _prepare_inputs(self) -> None:
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
-            bad_words_token_ids=self.share_inputs["bad_tokens"],
+            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
         )

fastdeploy/worker/xpu_model_runner.py

Lines changed: 14 additions & 2 deletions
@@ -506,6 +506,14 @@ def process_prefill_inputs(self, req_dicts: List[Request]):
                 request.block_tables, dtype="int32"
             )

+            if request.get("bad_words_token_ids") is not None:
+                bad_words_len = len(request.get("bad_words_token_ids"))
+                if bad_words_len > 0:
+                    self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
+                    self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
+                        request.get("bad_words_token_ids"), dtype="int64"
+                    )
+
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):

@@ -574,7 +582,8 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")

-        self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64")
+        self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
+        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")

@@ -652,6 +661,9 @@ def _prepare_inputs(self, is_dummy_run=False) -> None:
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
         )
+        # Update bad tokens len
+        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
+
         self.forward_meta.attn_backend = self.attn_backends[0]
         self.initialize_attention_backend()

@@ -667,7 +679,7 @@ def _prepare_inputs(self, is_dummy_run=False) -> None:
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
-            bad_words_token_ids=self.share_inputs["bad_tokens"],
+            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
         )
