@@ -448,6 +448,14 @@ def get_attr_from_request(request, attr, default_value=None):
448448 request .block_tables , dtype = "int32"
449449 )
450450
451+ if request .get ("bad_words_token_ids" ) is not None :
452+ bad_words_len = len (request .get ("bad_words_token_ids" ))
453+ if bad_words_len > 0 :
454+ self .share_inputs ["bad_tokens_len" ][idx : idx + 1 ] = bad_words_len
455+ self .share_inputs ["bad_tokens" ][idx : idx + 1 , :bad_words_len ] = np .array (
456+ request .get ("bad_words_token_ids" ), dtype = "int64"
457+ )
458+
451459 if request .get ("stop_token_ids" ) is not None and request .get ("stop_seqs_len" ) is not None :
452460 stop_seqs_num = len (request .get ("stop_seqs_len" ))
453461 for i in range (stop_seqs_num , self .model_config .max_stop_seqs_num ):
@@ -567,7 +575,8 @@ def _init_share_inputs(self, max_num_seqs: int):
567575 self .share_inputs ["stop_flags" ] = paddle .full ([max_num_seqs , 1 ], True , dtype = "bool" )
568576 self .share_inputs ["stop_nums" ] = paddle .full ([1 ], max_num_seqs , dtype = "int64" )
569577
570- self .share_inputs ["bad_tokens" ] = paddle .full ([1 ], - 1 , dtype = "int64" )
578+ self .share_inputs ["bad_tokens" ] = paddle .full ([max_num_seqs , self .model_config .vocab_size ], - 1 , dtype = "int64" )
579+ self .share_inputs ["bad_tokens_len" ] = paddle .full ([max_num_seqs ], 1 , dtype = "int64" )
571580 self .share_inputs ["next_tokens" ] = paddle .full ([max_num_seqs , 1 ], - 1 , dtype = "int64" )
572581 self .share_inputs ["is_block_step" ] = paddle .full ([max_num_seqs ], False , dtype = "bool" )
573582 self .share_inputs ["encoder_block_lens" ] = paddle .full ([max_num_seqs ], 0 , dtype = "int32" )
@@ -733,6 +742,9 @@ def _prepare_inputs(self) -> None:
733742 self .share_inputs ["output_cum_offsets" ].copy_ (output_cum_offsets , False )
734743 self .share_inputs ["output_padding_offset" ].copy_ (output_padding_offset , False )
735744
745+ # Update bad tokens len
746+ max_bad_tokens_len = paddle .max (self .share_inputs ["bad_tokens_len" ])
747+
736748 # Initialize forward meta data
737749 self .initialize_forward_meta ()
738750
@@ -750,7 +762,7 @@ def _prepare_inputs(self) -> None:
750762 presence_penalties = self .share_inputs ["presence_score" ],
751763 repetition_penalties = self .share_inputs ["penalty_score" ],
752764 min_dec_lens = self .share_inputs ["min_dec_len" ],
753- bad_words_token_ids = self .share_inputs ["bad_tokens" ],
765+ bad_words_token_ids = self .share_inputs ["bad_tokens" ][:, : max_bad_tokens_len ] ,
754766 eos_token_ids = self .share_inputs ["eos_token_id" ],
755767 max_num_logprobs = 20 if self .enable_logprob else None ,
756768 enable_early_stop = self .enable_early_stop ,
0 commit comments