Skip to content

Commit 95b5af2

Browse files
authored
[SOT] Add sot warmup (NVIDIA GPU Only) (PaddlePaddle#2929)
* Add SOT warmup * Fix code style * Change batch_size list * Add parameter to config * Remove free_list settings and set sot_warmup_sizes * Finish debugging dynamic dims via type annotations * Add profile_run guard * Remove unused code
1 parent 7c5e34e commit 95b5af2

7 files changed

Lines changed: 71 additions & 18 deletions

File tree

fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ class GraphOptimizationConfig:
319319
- With dynamic graph backend: ...
320320
- With static graph backend: WIP
321321
"""
322+
sot_warmup_sizes: Optional[list[int]] = field(default_factory=list)
323+
""" Number of warmup runs for SOT warmup. """
322324
use_cudagraph: bool = False
323325
"""Sizes to capture cudagraph.
324326
- None (default): capture sizes are inferred from llm config.

fastdeploy/engine/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ def __init__(
429429
graph_opt_level: Optional[int] = 0,
430430
use_cudagraph: Optional[bool] = None,
431431
cudagraph_capture_sizes: Optional[List[int]] = None,
432+
sot_warmup_sizes: Optional[List[int]] = None,
432433
**kwargs,
433434
):
434435
"""
@@ -444,6 +445,7 @@ def __init__(
444445
self.graph_opt_level = graph_opt_level
445446
self.use_cudagraph = use_cudagraph
446447
self.cudagraph_capture_sizes = cudagraph_capture_sizes
448+
self.sot_warmup_sizes = [] if sot_warmup_sizes is None else sot_warmup_sizes
447449

448450
def to_json_string(self):
449451
"""

fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,15 @@
3131
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
3232
resolve_dynamic_dims,
3333
)
34+
from fastdeploy.model_executor.graph_optimization.utils import in_profile_run_mode
35+
from fastdeploy.model_executor.graph_optimization.utils import (
36+
in_sot_warmup_mode as in_warmup_mode,
37+
)
3438

3539
P = ParamSpec("P")
3640
T = TypeVar("T")
3741

3842

39-
# TODO(SigureMo): Replace this fn with real implementation by DrRyanHuang
40-
def create_in_warmup_mode():
41-
cnt = 0
42-
43-
def in_warmup_mode():
44-
nonlocal cnt
45-
cnt += 1
46-
return cnt < 32
47-
48-
return in_warmup_mode
49-
50-
51-
in_warmup_mode = create_in_warmup_mode()
52-
53-
5443
def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -> Callable[P, T]:
5544
forward_fn = fn
5645
forward_sig = inspect.signature(forward_fn)
@@ -99,6 +88,8 @@ def warmup_impl(self, *args, **kwargs):
9988

10089
@functools.wraps(forward_fn)
10190
def static_forward(self, *args, **kwargs):
91+
if in_profile_run_mode():
92+
return forward_fn(self, *args, **kwargs)
10293
nonlocal need_warmup
10394
is_warmup = in_warmup_mode() and need_warmup
10495
if is_warmup:
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import contextlib
18+
19+
20+
def create_guard(default_value):
21+
_state = default_value
22+
23+
@contextlib.contextmanager
24+
def state_guard(current_state):
25+
nonlocal _state
26+
old_state = _state
27+
_state = current_state
28+
try:
29+
yield
30+
finally:
31+
_state = old_state
32+
33+
def get_state():
34+
return _state
35+
36+
return state_guard, get_state
37+
38+
39+
sot_warmup_guard, in_sot_warmup_mode = create_guard(False)
40+
profile_run_guard, in_profile_run_mode = create_guard(False)

fastdeploy/worker/gpu_model_runner.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525

2626
from fastdeploy.config import FDConfig
2727
from fastdeploy.engine.request import Request
28+
from fastdeploy.model_executor.graph_optimization.utils import (
29+
profile_run_guard,
30+
sot_warmup_guard,
31+
)
2832
from fastdeploy.model_executor.guided_decoding import get_guided_backend
2933
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
3034
LogitsProcessorBase,
@@ -113,8 +117,10 @@ def __init__(
113117
# self.kv_caches: list[paddle.Tensor] = []
114118

115119
# Cuda Graph
120+
self.graph_opt_level = self.graph_opt_config.graph_opt_level
116121
self.use_cudagraph = self.graph_opt_config.use_cudagraph
117122
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
123+
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
118124

119125
# Initialize share inputs
120126
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -367,9 +373,6 @@ def get_attr_from_request(request, attr, default_value=None):
367373
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
368374
"""Set dummy prefill inputs to share_inputs"""
369375
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
370-
if self.enable_mm:
371-
self.share_inputs["free_list"] = paddle.to_tensor([], dtype="int32")
372-
self.share_inputs["free_list_len"][0] = 0
373376
max_dec_len = expected_decode_len + 1
374377
full_length = min(
375378
num_tokens // batch_size,
@@ -1007,6 +1010,17 @@ def capture_model(self) -> None:
10071010
time_after_capture = time.perf_counter()
10081011
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
10091012

1013+
@sot_warmup_guard(True)
1014+
def sot_warmup(self) -> None:
1015+
start_time = time.perf_counter()
1016+
for batch_size in self.sot_warmup_sizes:
1017+
self._dummy_run(
1018+
num_tokens=self.parallel_config.max_num_batched_tokens,
1019+
batch_size=batch_size,
1020+
)
1021+
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
1022+
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
1023+
10101024
def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
10111025
"""
10121026
Get the index of the request that needs to be skipped during execution.
@@ -1208,6 +1222,7 @@ def _execute_empty_input(self) -> None:
12081222
else:
12091223
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
12101224

1225+
@profile_run_guard(True)
12111226
def profile_run(self) -> None:
12121227
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""
12131228

fastdeploy/worker/gpu_worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def graph_optimize_and_warm_up_model(self) -> None:
189189
"""
190190
Perform the warm-up and the graph optimization
191191
"""
192+
if self.model_runner.graph_opt_level >= 1:
193+
self.model_runner.sot_warmup()
192194
# Trigger cuda graph capture
193195
self.model_runner.capture_model()
194196

fastdeploy/worker/worker_process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
632632
use_cudagraph=args.graph_optimization_config["use_cudagraph"],
633633
graph_opt_level=args.graph_optimization_config["graph_opt_level"],
634634
cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"],
635+
sot_warmup_sizes=args.graph_optimization_config["sot_warmup_sizes"],
635636
)
636637

637638
# Note(tangbinhan): used for load_checkpoint

0 commit comments

Comments
 (0)