Skip to content

Commit 0e0eaa1

Browse files
authored
[BugFix] fix mm revert bug (PaddlePaddle#6061)
* fix mm revert bug * update code
1 parent 70a962d commit 0e0eaa1

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,15 @@ def revert_chunked_mm_input(self, mm_inputs, matched_token_num):
402402
position.offset // self.config.cache_config.block_size
403403
) * self.config.cache_config.block_size
404404
position_idx -= 1
405-
elif matched_token_num < position.offset:
405+
elif matched_token_num <= position.offset:
406406
position_idx -= 1
407407
elif matched_token_num >= position.offset + position.length:
408408
break
409+
else:
410+
llm_logger.error(
411+
f"revert_chunked_mm_input error, matched_token_num:{matched_token_num} position:{position}, {mm_inputs['mm_positions']}"
412+
)
413+
break
409414
return matched_token_num
410415

411416
def _get_num_new_tokens(self, request, token_budget):
@@ -454,6 +459,18 @@ def _compute_audio_prefix_count(end_idx, end_patch_idx):
454459
start_patch_idx = inputs["patch_idx"][-1]
455460
else:
456461
start_patch_idx = inputs["patch_idx"][pre_end_idx]
462+
if (
463+
pre_end_idx > 0
464+
and request.prompt_token_ids[pre_end_idx]
465+
in [
466+
inputs["image_patch_id"],
467+
inputs["video_patch_id"],
468+
inputs["audio_patch_id"],
469+
]
470+
and request.prompt_token_ids[pre_end_idx] != request.prompt_token_ids[pre_end_idx - 1]
471+
):
472+
# It just hit the starting position of the image / video / audio
473+
start_patch_idx -= 1
457474
start_patch_map = inputs["patch_map"][start_patch_idx]
458475
request.image_start = start_patch_map["image_num"]
459476
request.video_start = start_patch_map["video_num"]

tests/v1/test_resource_manager_v1.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,15 @@ def test_revert_chunked_mm_input_after_last_chunk(self):
284284
result = self.manager.revert_chunked_mm_input(mm_inputs, 256)
285285
self.assertEqual(result, 256)
286286

287+
def test_revert_chunked_mm_input_match_image_offset(self):
288+
mm_inputs = {
289+
"mm_positions": [
290+
ImagePosition(offset=64, length=21),
291+
]
292+
}
293+
result = self.manager.revert_chunked_mm_input(mm_inputs, 64)
294+
self.assertEqual(result, 64)
295+
287296

288297
if __name__ == "__main__":
289298
unittest.main()

0 commit comments

Comments
 (0)