@@ -905,6 +905,7 @@ def __init__(
905905 - None (default): capture sizes are inferred from llm config.
906906 - list[int]: capture sizes are specified as given."""
907907 self .cudagraph_capture_sizes : Optional [list [int ]] = None
908+ self .cudagraph_capture_sizes_prefill : list [int ] = [1 , 2 , 4 , 8 ]
908909 """ Number of warmup runs for cudagraph. """
909910 self .cudagraph_num_of_warmups : int = 2
910911 """Whether to copy input tensors for cudagraph.
@@ -942,7 +943,7 @@ def __init__(
942943 """ Maximum CUDA Graph capture size for static graph mode.
943944 Recommend 512 for small models (e.g., ERNIE45T 0.3B) and 128 for massive models (e.g., 300B).
944945 """
945- self .max_capture_shape_dy2st : int = 512
946+ self .max_capture_shape_prefill : int = 512
946947
947948 # CINN Config ...
948949 if args is not None :
@@ -952,13 +953,16 @@ def __init__(
952953
953954 self .check_legality_parameters ()
954955
955- def init_with_cudagrpah_size (self , max_capture_size : int = 0 ) -> None :
956+ def init_with_cudagrpah_size (self , max_capture_size : int = 0 , max_capture_shape_prefill : int = 0 ) -> None :
956957 """
957958 Initialize cuda graph capture sizes and
958959 pre-compute the mapping from batch size to padded graph size
959960 """
960961 # Regular capture sizes
961962 self .cudagraph_capture_sizes = [size for size in self .cudagraph_capture_sizes if size <= max_capture_size ]
963+ self .cudagraph_capture_sizes_prefill = [
964+ size for size in self .cudagraph_capture_sizes_prefill if size <= max_capture_shape_prefill
965+ ]
962966 dedup_sizes = list (set (self .cudagraph_capture_sizes ))
963967 if len (dedup_sizes ) < len (self .cudagraph_capture_sizes ):
964968 logger .info (
@@ -970,7 +974,11 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
970974
971975 # Sort to make sure cudagraph capture sizes are in descending order
972976 self .cudagraph_capture_sizes .sort (reverse = True )
977+ self .cudagraph_capture_sizes_prefill .sort (reverse = True )
973978 self .max_capture_size = self .cudagraph_capture_sizes [0 ] if self .cudagraph_capture_sizes else 0
979+ self .max_capture_size_prefill = (
980+ self .cudagraph_capture_sizes_prefill [0 ] if self .cudagraph_capture_sizes_prefill else 0
981+ )
974982
975983 # Pre-compute the mapping from shape to padded graph size
976984 self .real_shape_to_captured_size = {}
@@ -982,7 +990,21 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
982990 self .real_shape_to_captured_size [bs ] = end
983991 self .real_shape_to_captured_size [self .max_capture_size ] = self .max_capture_size
984992
985- def _set_cudagraph_sizes (self , max_capture_size : int = 0 , dec_token_per_query_per_step : int = 1 ):
993+ self .real_shape_to_captured_size_prefill = {}
994+ for end , start in zip (self .cudagraph_capture_sizes_prefill , self .cudagraph_capture_sizes_prefill [1 :] + [0 ]):
995+ for bs in range (start , end ):
996+ if bs == start :
997+ self .real_shape_to_captured_size_prefill [bs ] = start
998+ else :
999+ self .real_shape_to_captured_size_prefill [bs ] = end
1000+ self .real_shape_to_captured_size_prefill [self .max_capture_size_prefill ] = self .max_capture_size_prefill
1001+
1002+ def _set_cudagraph_sizes (
1003+ self ,
1004+ max_capture_size : int = 0 ,
1005+ max_capture_shape_prefill : int = 0 ,
1006+ dec_token_per_query_per_step : int = 1 ,
1007+ ):
9861008 """
9871009 Calculate a series of candidate capture sizes,
9881010 and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -996,14 +1018,21 @@ def _set_cudagraph_sizes(self, max_capture_size: int = 0, dec_token_per_query_pe
9961018 # Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step
9971019 draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range (9 , 33 )]
9981020
1021+ draft_capture_sizes_prefill = draft_capture_sizes .copy ()
9991022 draft_capture_sizes .append (max_capture_size )
10001023 self .cudagraph_capture_sizes = sorted (draft_capture_sizes )
10011024
1025+ draft_capture_sizes_prefill .append (max_capture_shape_prefill )
1026+ self .cudagraph_capture_sizes_prefill = sorted (draft_capture_sizes_prefill )
1027+
10021028 def filter_capture_size (self , tp_size : int = 1 ):
10031029 """When TSP is used, capture size must be divisible by tp size."""
10041030 self .cudagraph_capture_sizes = [
10051031 draft_size for draft_size in self .cudagraph_capture_sizes if (draft_size % tp_size == 0 )
10061032 ]
1033+ self .cudagraph_capture_sizes_prefill = [
1034+ draft_size for draft_size in self .cudagraph_capture_sizes_prefill if (draft_size % tp_size == 0 )
1035+ ]
10071036
10081037 def to_json_string (self ):
10091038 """
@@ -1672,8 +1701,7 @@ def __init__(
16721701 else :
16731702 max_capture_shape = min (512 , max_capture_shape )
16741703
1675- if self .graph_opt_config .graph_opt_level > 0 :
1676- max_capture_shape = graph_opt_config .max_capture_shape_dy2st
1704+ max_capture_shape_prefill = graph_opt_config .max_capture_shape_prefill
16771705
16781706 if self .graph_opt_config .cudagraph_capture_sizes is None :
16791707 dec_token_per_query_per_step = (
@@ -1682,9 +1710,14 @@ def __init__(
16821710 else 1
16831711 )
16841712 self .graph_opt_config ._set_cudagraph_sizes (
1685- max_capture_size = max_capture_shape , dec_token_per_query_per_step = dec_token_per_query_per_step
1713+ max_capture_size = max_capture_shape ,
1714+ max_capture_shape_prefill = max_capture_shape_prefill ,
1715+ dec_token_per_query_per_step = dec_token_per_query_per_step ,
16861716 )
1687- self .graph_opt_config .init_with_cudagrpah_size (max_capture_size = max_capture_shape )
1717+ self .graph_opt_config .init_with_cudagrpah_size (
1718+ max_capture_size = max_capture_shape ,
1719+ max_capture_shape_prefill = max_capture_shape_prefill ,
1720+ )
16881721
16891722 self .tokenizer = tokenizer
16901723 self .ips = ips
0 commit comments