Commit 8d27a52

liyonghua0910, Jiang-Jia-Jun, and Copilot authored
[Feature] [KVCache] support attention_store kv cache backend (PaddlePaddle#5823)
* [feat] support attention_store kv cache backend
* [fix] fix codestyle
* [chore] optimize log
* [fix] fix write storage task
* [fix] fix read storage
* [fix] fix code conflict after merge develop
* [fix] fix cache bytes and read task token ids
* [chore] add model for cache transfer manager
* [chore] add some log
* [chore] remove launched_cache_manager_signal
* [fix] fix write_back_storage_task match_block_num condition
* [fix] fix swap_cost_time
* [ci] fix ci
* Update fastdeploy/engine/sched/resource_manager_v1.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update fastdeploy/cache_manager/cache_transfer_manager.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3cd0ffe commit 8d27a52

17 files changed

Lines changed: 603 additions & 230 deletions

docs/zh/online_serving/metrics.md

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 | KV cache | `fastdeploy:gpu_hit_token_rate` | Gauge | Token-level GPU prefix cache hit rate | percentage |
 | KV cache | `fastdeploy:prefix_cache_token_num` | Counter | Total number of prefix cache tokens ||
 | KV cache | `fastdeploy:prefix_gpu_cache_token_num` | Counter | Total number of prefix cache tokens located on the GPU ||
-| KV cache | `fastdeploy:prefix_cpu_cache_token_num` | Counter | Total number of prefix cache tokens located on the GPU ||
+| KV cache | `fastdeploy:prefix_cpu_cache_token_num` | Counter | Total number of prefix cache tokens located on the CPU ||
 | KV cache | `fastdeploy:available_gpu_block_num` | Gauge | Number of GPU blocks available in the cache (including prefix cache blocks not yet formally released) ||
 | KV cache | `fastdeploy:free_gpu_block_num` | Gauge | Number of free blocks in the cache ||
 | KV cache | `fastdeploy:max_gpu_block_num` | Gauge | Total number of blocks determined at service startup ||

fastdeploy/cache_manager/cache_messager.py

Lines changed: 4 additions & 1 deletion
@@ -1051,7 +1051,10 @@ def main():
 
     args = parse_args()
     rank_id = args.rank + args.local_data_parallel_id * args.mp_num
-    logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
+    if args.mp_num > 1:
+        logger = get_logger("cache_messager", f"cache_messager_{rank_id}.log")
+    else:
+        logger = get_logger("cache_messager", "cache_messager.log")
 
     logger.info("create cache messager...")
     logger.info(f"{args}")
fastdeploy/cache_manager/cache_tasks.py (new file)

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass(frozen=True, kw_only=True)
+class CacheTask:
+    task_id: str
+    keys: List[str]
+    token_ids: List[int]
+    gpu_block_ids: List[int]
+
+
+@dataclass(frozen=True, kw_only=True)
+class ReadStorageTask(CacheTask):
+    start_read_block_idx: int
+    timeout: float = 30.0
+
+
+@dataclass(frozen=True, kw_only=True)
+class WriteStorageTask(CacheTask):
+    timeout: float = 30.0
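These task dataclasses replace the loose (req_id, hash_keys, gpu_block_ids, timeout) tuples previously passed through the cache task queue. A minimal construction sketch; the request id, keys, token ids, and block ids below are hypothetical (dataclass kw_only requires Python 3.10+):

from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask

# Hypothetical request data, for illustration only.
read_task = ReadStorageTask(
    task_id="req-0",
    keys=["blockhash-a", "blockhash-b"],
    token_ids=list(range(128)),
    gpu_block_ids=[3, 7],
    start_read_block_idx=0,   # index of the first block to fetch from storage
)                             # timeout keeps its 30.0s default

write_task = WriteStorageTask(
    task_id="req-0",
    keys=["blockhash-a", "blockhash-b"],
    token_ids=list(range(128)),
    gpu_block_ids=[3, 7],
)

# frozen=True makes the tasks immutable; kw_only=True means positional
# construction such as ReadStorageTask("req-0", ...) raises a TypeError.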

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 276 additions & 196 deletions
Large diffs are not rendered by default.

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 42 additions & 16 deletions
@@ -31,7 +31,9 @@
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
 from fastdeploy.cache_manager.cache_metrics import CacheMetrics
+from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
 from fastdeploy.cache_manager.ops import get_all_visible_devices
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
 from fastdeploy.metrics.metrics import main_process_metrics
@@ -47,7 +49,7 @@ class PrefixCacheManager:
 
     def __init__(
         self,
-        config,
+        config: FDConfig,
         tensor_parallel_size,
         splitwise_role="mixed",
         local_data_parallel_id=0,
@@ -207,7 +209,6 @@ def launch_cache_manager(
         key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
         key_cache_shape = ",".join([str(i) for i in key_cache_shape])
         val_cache_shape = ",".join([str(i) for i in val_cache_shape])
-        logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
         if self.enable_splitwise:
             cache_messager_processes = self.launch_cache_messager(
                 cache_config,
@@ -273,6 +274,7 @@ def launch_cache_manager(
                 + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
                 + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
                 + f" {sys.executable} {py_path}"
+                + f" --model_id {os.path.basename(self.config.model_config.model)}"
                 + f" --device_id {int(device_ids[i])}"
                 + f" --rank {i}"
                 + f" --splitwise_role {self.splitwise_role}"
@@ -390,7 +392,7 @@ def launch_cache_messager(
                 + f" --ipc_suffix {ipc_suffix}"
                 + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
                 + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-                + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
+                + f" >{log_dir}/launch_cache_messager_{i}.log 2>&1"
             )
             logger.info(f"Launch cache messager, command:{launch_cmd}")
             cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -789,9 +791,15 @@ def request_match_blocks(self, task: Request, block_size, *args):
                     f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
                 )
                 start_time = time.time()
-                storage_matched_block_ids = self.issue_prefetch_storage_task(
-                    req_id, no_match_block_keys, gpu_recv_storage_block_ids
+                read_storage_task = ReadStorageTask(
+                    task_id=req_id,
+                    keys=no_match_block_keys,
+                    token_ids=input_token_ids,
+                    gpu_block_ids=gpu_recv_storage_block_ids,
+                    start_read_block_idx=match_token_num // block_size,
                 )
+                logger.debug(f"issue read storage task: {read_storage_task}")
+                storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
                 storage_matched_block_num = len(storage_matched_block_ids)
                 storage_match_token_num = storage_matched_block_num * block_size
                 cost_time = time.time() - start_time
@@ -1006,6 +1014,12 @@ def write_cache_to_storage(self, request: Request):
         if self.kvcache_storage_backend is None:
             return
 
+        token_ids = request.prompt_token_ids
+        if isinstance(token_ids, np.ndarray):
+            token_ids = token_ids.tolist()
+        if self.config.cache_config.enable_output_caching:
+            token_ids += request.output_token_ids
+
         req_id = request.request_id
         keys = []
         node = self.req_leaf_map[req_id]
@@ -1018,24 +1032,33 @@ def write_cache_to_storage(self, request: Request):
 
         gpu_block_ids = request.block_tables[: len(keys)]
         logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
+        write_storage_task = WriteStorageTask(
+            task_id=req_id,
+            keys=keys,
+            token_ids=token_ids,
+            gpu_block_ids=gpu_block_ids,
+        )
+        logger.debug(f"issue write storage task: {write_storage_task}")
         tic = time.time()
-        self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True)
+        self.issue_write_back_storage_task(write_storage_task, is_sync=True)
         cost_time = time.time() - tic
         logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
 
-    def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
+    def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
         if self.kvcache_storage_backend is None:
             return
 
-        if len(hash_keys) != len(gpu_block_ids):
-            err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})"
+        if len(task.keys) != len(task.gpu_block_ids):
+            err_msg = (
+                f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})"
+            )
             logger.error(err_msg)
             raise ValueError(err_msg)
 
-        self.task_write_back_event[req_id] = Event()
-        self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
+        self.task_write_back_event[task.task_id] = Event()
+        self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task))
         if is_sync:
-            self.wait_write_storage_task(req_id)
+            self.wait_write_storage_task(task.task_id)
 
     def wait_write_storage_task(self, req_id):
         """
@@ -1045,16 +1068,19 @@ def wait_write_storage_task(self, req_id):
         self.task_write_back_event[req_id].wait()
         del self.task_write_back_event[req_id]
 
-    def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
+    def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
         """
         Prefetch cache from storage task
        """
+        if self.kvcache_storage_backend is None:
+            return []
+
         storage_block_ids = []
-        self.task_prefetch_event[req_id] = Event()
+        self.task_prefetch_event[task.task_id] = Event()
         # issue task to cache_transfer_manager
-        self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
+        self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task))
         if is_sync:
-            storage_block_ids = self.wait_prefetch_storage_task(req_id)
+            storage_block_ids = self.wait_prefetch_storage_task(task.task_id)
         return storage_block_ids
 
     def wait_prefetch_storage_task(self, req_id):
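Both issue_* methods above now follow the same pattern: register a per-task Event, push a single (CacheStatus, task) tuple onto cache_task_queue, and optionally block until the cache transfer manager signals completion. A stripped-down sketch of that pattern; the class, queue, and event-map names outside the diff are illustrative stand-ins, not FastDeploy API:

from threading import Event

class TransferIssuer:
    """Illustrative stand-in for the issue/wait pattern used by PrefixCacheManager."""

    def __init__(self, cache_task_queue):
        self.cache_task_queue = cache_task_queue   # EngineCacheQueue in the real code
        self.events = {}                           # task_id -> Event, set by the cache transfer manager

    def issue(self, status, task, is_sync=True):
        # One (CacheStatus, task) tuple replaces the old five-field tuple
        # (status, req_id, hash_keys, gpu_block_ids, timeout).
        self.events[task.task_id] = Event()
        self.cache_task_queue.put_transfer_task((status, task))
        if is_sync:
            self.events[task.task_id].wait()       # released once the transfer completes
            del self.events[task.task_id]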

fastdeploy/cache_manager/transfer_factory/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,7 @@
 from fastdeploy.platforms import current_platform
 
 from .kvcache_storage import KVCacheStorage
-from .mooncake_store import MooncakeStore
+from .mooncake_store import AttentionStore, MooncakeStore
 from .rdma_cache_transfer import RDMACommManager
 
 if current_platform.is_cuda():
@@ -31,4 +31,5 @@
     "RDMACommManager",
     "KVCacheStorage",
     "MooncakeStore",
+    "AttentionStore",
 ]

fastdeploy/cache_manager/transfer_factory/kvcache_storage.py

Lines changed: 7 additions & 0 deletions
@@ -95,3 +95,10 @@ def clear(self) -> bool:
         Clear all keys in storage
         """
         pass
+
+    @abstractmethod
+    def query(self) -> int:
+        """
+        Query the number of blocks stored in the storage.
+        """
+        pass
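The new abstract query() hook lets callers ask a storage backend how many blocks it currently holds. A minimal sketch of a subclass implementing it; the in-memory backend below is hypothetical, and KVCacheStorage's other abstract methods are not part of this diff and are omitted:

from fastdeploy.cache_manager.transfer_factory import KVCacheStorage

class InMemoryKVCacheStorage(KVCacheStorage):    # hypothetical backend, for illustration only
    def __init__(self):
        self._blocks = {}                        # key -> serialized KV block

    def clear(self) -> bool:
        self._blocks.clear()
        return True

    def query(self) -> int:
        # New in this commit: report the number of blocks currently stored.
        return len(self._blocks)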

fastdeploy/cache_manager/transfer_factory/mooncake_store/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+from .attention_store import AttentionStore
 from .mooncake_store import MooncakeStore
 
-__all__ = ["MooncakeStore"]
+__all__ = ["MooncakeStore", "AttentionStore"]
