|
25 | 25 |
|
26 | 26 | from fastdeploy.config import FDConfig |
27 | 27 | from fastdeploy.engine.request import Request |
| 28 | +from fastdeploy.model_executor.graph_optimization.utils import ( |
| 29 | + profile_run_guard, |
| 30 | + sot_warmup_guard, |
| 31 | +) |
28 | 32 | from fastdeploy.model_executor.guided_decoding import get_guided_backend |
29 | 33 | from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( |
30 | 34 | LogitsProcessorBase, |
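The two new imports, `sot_warmup_guard` and `profile_run_guard`, come from `fastdeploy.model_executor.graph_optimization.utils`, whose implementation is not part of this diff. As a rough illustration of the pattern such guards typically follow, here is a minimal flag-toggling decorator-factory sketch; everything except the two decorator names is an assumption, not the library's actual code:

```python
# Hypothetical sketch only: the real guards live in
# fastdeploy.model_executor.graph_optimization.utils and may behave differently.
import functools
from contextlib import contextmanager


@contextmanager
def _temporary_flag(obj, attr, value):
    """Set obj.<attr> to value for the duration of the block, then restore it."""
    previous = getattr(obj, attr, None)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, previous)


def sot_warmup_guard(enabled: bool):
    """Decorator factory: mark the runner as being in SOT warmup (assumed
    attribute name `in_sot_warmup`) while the wrapped method executes."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            if not enabled:
                return fn(self, *args, **kwargs)
            with _temporary_flag(self, "in_sot_warmup", True):
                return fn(self, *args, **kwargs)

        return wrapper

    return decorator
```

`profile_run_guard` would plausibly follow the same shape, toggling a profiling flag instead.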
@@ -113,8 +117,10 @@ def __init__( |
113 | 117 | # self.kv_caches: list[paddle.Tensor] = [] |
114 | 118 |
|
115 | 119 | # Cuda Graph |
| 120 | + self.graph_opt_level = self.graph_opt_config.graph_opt_level |
116 | 121 | self.use_cudagraph = self.graph_opt_config.use_cudagraph |
117 | 122 | self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) |
| 123 | + self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes |
118 | 124 |
|
119 | 125 | # Initialize share inputs |
120 | 126 | self._init_share_inputs(self.parallel_config.max_num_seqs) |
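The constructor now also reads `graph_opt_level` and the new `sot_warmup_sizes` from `self.graph_opt_config`, alongside the existing CUDA graph settings. A sketch of the assumed shape of that config object follows; only the four field names read above are taken from the diff, the defaults and comments are illustrative:

```python
# Illustrative sketch of the graph-optimization settings the runner reads.
# The real GraphOptimizationConfig in fastdeploy.config may define more fields
# and different defaults.
from dataclasses import dataclass, field
from typing import List


@dataclass
class GraphOptConfigSketch:
    # Optimization level for dynamic-to-static (SOT) conversion; 0 keeps the
    # plain dynamic graph (assumed meaning).
    graph_opt_level: int = 0
    # Whether decode steps are replayed through captured CUDA graphs.
    use_cudagraph: bool = False
    # Batch sizes to capture; the runner iterates them largest-first (reversed).
    cudagraph_capture_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
    # Batch sizes used to warm up SOT compilation before serving.
    sot_warmup_sizes: List[int] = field(default_factory=lambda: [1, 8, 32])
```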
@@ -367,9 +373,6 @@ def get_attr_from_request(request, attr, default_value=None): |
367 | 373 | def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): |
368 | 374 | """Set dummy prefill inputs to share_inputs""" |
369 | 375 | # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token |
370 | | - if self.enable_mm: |
371 | | - self.share_inputs["free_list"] = paddle.to_tensor([], dtype="int32") |
372 | | - self.share_inputs["free_list_len"][0] = 0 |
373 | 376 | max_dec_len = expected_decode_len + 1 |
374 | 377 | full_length = min( |
375 | 378 | num_tokens // batch_size, |
@@ -1007,6 +1010,17 @@ def capture_model(self) -> None: |
1007 | 1010 | time_after_capture = time.perf_counter() |
1008 | 1011 | logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") |
1009 | 1012 |
|
| 1013 | + @sot_warmup_guard(True) |
| 1014 | + def sot_warmup(self) -> None: |
| 1015 | + start_time = time.perf_counter() |
| 1016 | + for batch_size in self.sot_warmup_sizes: |
| 1017 | + self._dummy_run( |
| 1018 | + num_tokens=self.parallel_config.max_num_batched_tokens, |
| 1019 | + batch_size=batch_size, |
| 1020 | + ) |
 | 1021 | +            logger.info(f"SOT warmed up the model with batch size: {batch_size}")
| 1022 | + logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") |
| 1023 | + |
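The new `sot_warmup` method loops over `self.sot_warmup_sizes` and issues one dummy run per batch size under `sot_warmup_guard(True)`. A hedged sketch of how a worker might drive it before serving; the surrounding call site is an assumption, only the method and attribute names come from this diff:

```python
# Plausible warmup sequence (assumed call site, not taken from this PR).
def warmup_model(runner) -> None:
    # Warm up dynamic-to-static (SOT) compilation first so every batch size
    # that will later be captured or served is already compiled.
    if runner.graph_opt_level > 0:  # assumed condition for enabling SOT
        runner.sot_warmup()
    # Then capture CUDA graphs on top of the warmed-up static graph.
    if runner.use_cudagraph:
        runner.capture_model()
```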
1010 | 1024 | def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None): |
1011 | 1025 | """ |
1012 | 1026 | Get the index of the request that needs to be skipped during execution. |
@@ -1208,6 +1222,7 @@ def _execute_empty_input(self) -> None: |
1208 | 1222 | else: |
1209 | 1223 |             raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward'")
1210 | 1224 |
|
| 1225 | + @profile_run_guard(True) |
1211 | 1226 | def profile_run(self) -> None: |
1212 | 1227 | """Execute a forward pass with dummy inputs to profile the memory usage of the model""" |
1213 | 1228 |
|
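`profile_run` keeps its existing dummy-input forward pass but is now wrapped in `profile_run_guard(True)`. For orientation, a sketch of how a caller could use it to estimate the memory left for the KV cache; the sizing formula and helper name are illustrative, only `profile_run` itself comes from the diff:

```python
import paddle


def estimate_kv_cache_budget(runner, gpu_memory_utilization: float = 0.9) -> int:
    """Illustrative only: run the guarded dummy forward pass and derive how
    much GPU memory remains for the KV cache. FastDeploy's real sizing logic
    may differ."""
    paddle.device.cuda.empty_cache()
    runner.profile_run()  # forward with dummy inputs, under profile_run_guard
    peak_bytes = paddle.device.cuda.max_memory_reserved()
    total_bytes = paddle.device.cuda.get_device_properties().total_memory
    return int(total_bytes * gpu_memory_utilization) - peak_bytes
```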