Skip to content

Commit 7c5e34e

Browse files
authored
[FIX] Fix rejection sampling when top_p=0 using _SAMPLING_EPS (PaddlePaddle#2967)
* fix rejection sampling when top_p=0
* fix
1 parent dbe6225 commit 7c5e34e

3 files changed

Lines changed: 9 additions & 1 deletion

File tree

custom_ops/gpu_ops/sample_kernels/sampling.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -292,7 +292,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
292292
curand_init(philox_seed, bx, philox_offset, &state);
293293
const uint32_t row_idx = bx;
294294
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
295-
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
295+
const float p = top_p_arr[row_idx];
296296

297297
extern __shared__ __align__(
298298
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))

fastdeploy/input/ernie_processor.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,8 @@ def process_request(self, request, max_model_len=None, **kwargs):
123123
if request.get("temperature") < _SAMPLING_EPS:
124124
# zero temperature is equivalent to greedy sampling
125125
request.set("temperature", 1)
126+
if request.get("top_p") < _SAMPLING_EPS:
127+
request.set("top_p", _SAMPLING_EPS)
126128
data_processor_logger.info(f"Processed request {request}")
127129
return request
128130

@@ -174,6 +176,8 @@ def process_request_dict(self, request, max_model_len=None):
174176
if request.get("temperature") < _SAMPLING_EPS:
175177
# zero temperature is equivalent to greedy sampling
176178
request["temperature"] = 1
179+
if request.get("top_p") < _SAMPLING_EPS:
180+
request["top_p"] = _SAMPLING_EPS
177181
data_processor_logger.info(f"Processed request {request}")
178182

179183
return request

fastdeploy/input/text_processor.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,8 @@ def process_request(self, request, max_model_len=None, **kwargs):
252252
if request.get("temperature") < _SAMPLING_EPS:
253253
# zero temperature is equivalent to greedy sampling
254254
request.set("temperature", 1)
255+
if request.get("top_p") < _SAMPLING_EPS:
256+
request.set("top_p", _SAMPLING_EPS)
255257
data_processor_logger.info(f"Processed request {request}")
256258
return request
257259

@@ -297,6 +299,8 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
297299
if request.get("temperature") < _SAMPLING_EPS:
298300
# zero temperature is equivalent to greedy sampling
299301
request["temperature"] = 1
302+
if request.get("top_p") < _SAMPLING_EPS:
303+
request["top_p"] = _SAMPLING_EPS
300304
data_processor_logger.info(f"Processed request {request}")
301305
return request
302306

0 commit comments

Comments
 (0)