Skip to content

Commit 31c219d

Browse files
authored
[Graph Optimization] Add max_capture_shape_prefill && cudagraph_capture_sizes_prefill (PaddlePaddle#6148)
* Add max_capture_shape_dy2st parameter to YAML config
* split cudagraph capture size between decode and prefill
* rm if
* add default value
1 parent 8d27a52 commit 31c219d

1 file changed

Lines changed: 40 additions & 7 deletions

File tree

fastdeploy/config.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ def __init__(
905905
- None (default): capture sizes are inferred from llm config.
906906
- list[int]: capture sizes are specified as given."""
907907
self.cudagraph_capture_sizes: Optional[list[int]] = None
908+
self.cudagraph_capture_sizes_prefill: list[int] = [1, 2, 4, 8]
908909
""" Number of warmup runs for cudagraph. """
909910
self.cudagraph_num_of_warmups: int = 2
910911
"""Whether to copy input tensors for cudagraph.
@@ -942,7 +943,7 @@ def __init__(
942943
""" Maximum CUDA Graph capture size for static graph mode.
943944
Recommend 512 for small models (e.g., ERNIE45T 0.3B) and 128 for massive models (e.g., 300B).
944945
"""
945-
self.max_capture_shape_dy2st: int = 512
946+
self.max_capture_shape_prefill: int = 512
946947

947948
# CINN Config ...
948949
if args is not None:
@@ -952,13 +953,16 @@ def __init__(
952953

953954
self.check_legality_parameters()
954955

955-
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
956+
def init_with_cudagrpah_size(self, max_capture_size: int = 0, max_capture_shape_prefill: int = 0) -> None:
956957
"""
957958
Initialize cuda graph capture sizes and
958959
pre-compute the mapping from batch size to padded graph size
959960
"""
960961
# Regular capture sizes
961962
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
963+
self.cudagraph_capture_sizes_prefill = [
964+
size for size in self.cudagraph_capture_sizes_prefill if size <= max_capture_shape_prefill
965+
]
962966
dedup_sizes = list(set(self.cudagraph_capture_sizes))
963967
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
964968
logger.info(
@@ -970,7 +974,11 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
970974

971975
# Sort to make sure cudagraph capture sizes are in descending order
972976
self.cudagraph_capture_sizes.sort(reverse=True)
977+
self.cudagraph_capture_sizes_prefill.sort(reverse=True)
973978
self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
979+
self.max_capture_size_prefill = (
980+
self.cudagraph_capture_sizes_prefill[0] if self.cudagraph_capture_sizes_prefill else 0
981+
)
974982

975983
# Pre-compute the mapping from shape to padded graph size
976984
self.real_shape_to_captured_size = {}
@@ -982,7 +990,21 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
982990
self.real_shape_to_captured_size[bs] = end
983991
self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size
984992

985-
def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_per_step: int = 1):
993+
self.real_shape_to_captured_size_prefill = {}
994+
for end, start in zip(self.cudagraph_capture_sizes_prefill, self.cudagraph_capture_sizes_prefill[1:] + [0]):
995+
for bs in range(start, end):
996+
if bs == start:
997+
self.real_shape_to_captured_size_prefill[bs] = start
998+
else:
999+
self.real_shape_to_captured_size_prefill[bs] = end
1000+
self.real_shape_to_captured_size_prefill[self.max_capture_size_prefill] = self.max_capture_size_prefill
1001+
1002+
def _set_cudagraph_sizes(
1003+
self,
1004+
max_capture_size: int = 0,
1005+
max_capture_shape_prefill: int = 0,
1006+
dec_token_per_query_per_step: int = 1,
1007+
):
9861008
"""
9871009
Calculate a series of candidate capture sizes,
9881010
and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -996,14 +1018,21 @@ def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_pe
9961018
# Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step
9971019
draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)]
9981020

1021+
draft_capture_sizes_prefill = draft_capture_sizes.copy()
9991022
draft_capture_sizes.append(max_capture_size)
10001023
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
10011024

1025+
draft_capture_sizes_prefill.append(max_capture_shape_prefill)
1026+
self.cudagraph_capture_sizes_prefill = sorted(draft_capture_sizes_prefill)
1027+
10021028
def filter_capture_size(self, tp_size: int = 1):
10031029
"""When TSP is used, capture size must be divisible by tp size."""
10041030
self.cudagraph_capture_sizes = [
10051031
draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
10061032
]
1033+
self.cudagraph_capture_sizes_prefill = [
1034+
draft_size for draft_size in self.cudagraph_capture_sizes_prefill if (draft_size % tp_size == 0)
1035+
]
10071036

10081037
def to_json_string(self):
10091038
"""
@@ -1672,8 +1701,7 @@ def __init__(
16721701
else:
16731702
max_capture_shape = min(512, max_capture_shape)
16741703

1675-
if self.graph_opt_config.graph_opt_level > 0:
1676-
max_capture_shape = graph_opt_config.max_capture_shape_dy2st
1704+
max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill
16771705

16781706
if self.graph_opt_config.cudagraph_capture_sizes is None:
16791707
dec_token_per_query_per_step = (
@@ -1682,9 +1710,14 @@ def __init__(
16821710
else 1
16831711
)
16841712
self.graph_opt_config._set_cudagraph_sizes(
1685-
max_capture_size=max_capture_shape, dec_token_per_query_per_step=dec_token_per_query_per_step
1713+
max_capture_size=max_capture_shape,
1714+
max_capture_shape_prefill=max_capture_shape_prefill,
1715+
dec_token_per_query_per_step=dec_token_per_query_per_step,
16861716
)
1687-
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
1717+
self.graph_opt_config.init_with_cudagrpah_size(
1718+
max_capture_size=max_capture_shape,
1719+
max_capture_shape_prefill=max_capture_shape_prefill,
1720+
)
16881721

16891722
self.tokenizer = tokenizer
16901723
self.ips = ips

0 commit comments

Comments (0)