3131from fastdeploy import envs
3232from fastdeploy .cache_manager .cache_data import BlockNode , CacheStatus
3333from fastdeploy .cache_manager .cache_metrics import CacheMetrics
34+ from fastdeploy .cache_manager .cache_tasks import ReadStorageTask , WriteStorageTask
3435from fastdeploy .cache_manager .ops import get_all_visible_devices
36+ from fastdeploy .config import FDConfig
3537from fastdeploy .engine .request import Request
3638from fastdeploy .inter_communicator import EngineCacheQueue , IPCSignal , PrefixTreeStatus
3739from fastdeploy .metrics .metrics import main_process_metrics
@@ -47,7 +49,7 @@ class PrefixCacheManager:
4749
4850 def __init__ (
4951 self ,
50- config ,
52+ config : FDConfig ,
5153 tensor_parallel_size ,
5254 splitwise_role = "mixed" ,
5355 local_data_parallel_id = 0 ,
@@ -207,7 +209,6 @@ def launch_cache_manager(
207209 key_cache_shape , val_cache_shape = self ._get_kv_cache_shape (cache_config .total_block_num )
208210 key_cache_shape = "," .join ([str (i ) for i in key_cache_shape ])
209211 val_cache_shape = "," .join ([str (i ) for i in val_cache_shape ])
210- logger .info (f"key_cache_shape { key_cache_shape } value_cache_shape { val_cache_shape } " )
211212 if self .enable_splitwise :
212213 cache_messager_processes = self .launch_cache_messager (
213214 cache_config ,
@@ -273,6 +274,7 @@ def launch_cache_manager(
273274 + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
274275 + f" FD_ENABLE_SWAP_SPACE_CLEARING={ envs .FD_ENABLE_SWAP_SPACE_CLEARING } "
275276 + f" { sys .executable } { py_path } "
277+ + f" --model_id { os .path .basename (self .config .model_config .model )} "
276278 + f" --device_id { int (device_ids [i ])} "
277279 + f" --rank { i } "
278280 + f" --splitwise_role { self .splitwise_role } "
@@ -390,7 +392,7 @@ def launch_cache_messager(
390392 + f" --ipc_suffix { ipc_suffix } "
391393 + f" --rdma_port { cache_config .local_rdma_comm_ports [i ] if cache_config .local_rdma_comm_ports is not None else '0' } "
392394 + f" --speculative_config '{ self .speculative_config .to_json_string ()} '"
393- + f" >{ log_dir } /launch_cache_messager_tprank { i } .log 2>&1"
395+ + f" >{ log_dir } /launch_cache_messager_ { i } .log 2>&1"
394396 )
395397 logger .info (f"Launch cache messager, command:{ launch_cmd } " )
396398 cache_messager_processes .append (subprocess .Popen (launch_cmd , shell = True , preexec_fn = os .setsid ))
@@ -789,9 +791,15 @@ def request_match_blocks(self, task: Request, block_size, *args):
789791 f"start prefetch cache from storage, req_id: { req_id } , block num: { len (no_match_block_keys )} "
790792 )
791793 start_time = time .time ()
792- storage_matched_block_ids = self .issue_prefetch_storage_task (
793- req_id , no_match_block_keys , gpu_recv_storage_block_ids
794+ read_storage_task = ReadStorageTask (
795+ task_id = req_id ,
796+ keys = no_match_block_keys ,
797+ token_ids = input_token_ids ,
798+ gpu_block_ids = gpu_recv_storage_block_ids ,
799+ start_read_block_idx = match_token_num // block_size ,
794800 )
801+ logger .debug (f"issue read storage task: { read_storage_task } " )
802+ storage_matched_block_ids = self .issue_prefetch_storage_task (read_storage_task )
795803 storage_matched_block_num = len (storage_matched_block_ids )
796804 storage_match_token_num = storage_matched_block_num * block_size
797805 cost_time = time .time () - start_time
@@ -1006,6 +1014,12 @@ def write_cache_to_storage(self, request: Request):
10061014 if self .kvcache_storage_backend is None :
10071015 return
10081016
1017+ token_ids = request .prompt_token_ids
1018+ if isinstance (token_ids , np .ndarray ):
1019+ token_ids = token_ids .tolist ()
1020+ if self .config .cache_config .enable_output_caching :
1021+ token_ids += request .output_token_ids
1022+
10091023 req_id = request .request_id
10101024 keys = []
10111025 node = self .req_leaf_map [req_id ]
@@ -1018,24 +1032,33 @@ def write_cache_to_storage(self, request: Request):
10181032
10191033 gpu_block_ids = request .block_tables [: len (keys )]
10201034 logger .info (f"start write cache back to storage, req_id: { req_id } , block num: { len (keys )} " )
1035+ write_storage_task = WriteStorageTask (
1036+ task_id = req_id ,
1037+ keys = keys ,
1038+ token_ids = token_ids ,
1039+ gpu_block_ids = gpu_block_ids ,
1040+ )
1041+ logger .debug (f"issue write storage task: { write_storage_task } " )
10211042 tic = time .time ()
1022- self .issue_write_back_storage_task (req_id = req_id , hash_keys = keys , gpu_block_ids = gpu_block_ids , is_sync = True )
1043+ self .issue_write_back_storage_task (write_storage_task , is_sync = True )
10231044 cost_time = time .time () - tic
10241045 logger .info (f"finish write cache back to storage, req_id: { req_id } , cost_time: { cost_time :.6f} s" )
10251046
1026- def issue_write_back_storage_task (self , req_id , hash_keys , gpu_block_ids , is_sync = True , timeout = 0.5 ):
1047+ def issue_write_back_storage_task (self , task : WriteStorageTask , is_sync = True ):
10271048 if self .kvcache_storage_backend is None :
10281049 return
10291050
1030- if len (hash_keys ) != len (gpu_block_ids ):
1031- err_msg = f"write_back_storage error: hash_keys({ len (hash_keys )} ) != gpu_block_ids({ len (gpu_block_ids )} )"
1051+ if len (task .keys ) != len (task .gpu_block_ids ):
1052+ err_msg = (
1053+ f"write_back_storage error: hash_keys({ len (task .keys )} ) != gpu_block_ids({ len (task .gpu_block_ids )} )"
1054+ )
10321055 logger .error (err_msg )
10331056 raise ValueError (err_msg )
10341057
1035- self .task_write_back_event [req_id ] = Event ()
1036- self .cache_task_queue .put_transfer_task ((CacheStatus .GPU2STORAGE , req_id , hash_keys , gpu_block_ids , timeout ))
1058+ self .task_write_back_event [task . task_id ] = Event ()
1059+ self .cache_task_queue .put_transfer_task ((CacheStatus .GPU2STORAGE , task ))
10371060 if is_sync :
1038- self .wait_write_storage_task (req_id )
1061+ self .wait_write_storage_task (task . task_id )
10391062
10401063 def wait_write_storage_task (self , req_id ):
10411064 """
@@ -1045,16 +1068,19 @@ def wait_write_storage_task(self, req_id):
10451068 self .task_write_back_event [req_id ].wait ()
10461069 del self .task_write_back_event [req_id ]
10471070
1048- def issue_prefetch_storage_task (self , req_id , hash_keys , gpu_block_ids , is_sync = True , timeout = 0.5 ):
1071+ def issue_prefetch_storage_task (self , task : ReadStorageTask , is_sync = True ):
10491072 """
10501073 Prefetch cache from storage task
10511074 """
1075+ if self .kvcache_storage_backend is None :
1076+ return []
1077+
10521078 storage_block_ids = []
1053- self .task_prefetch_event [req_id ] = Event ()
1079+ self .task_prefetch_event [task . task_id ] = Event ()
10541080 # issue task to cache_transfer_manager
1055- self .cache_task_queue .put_transfer_task ((CacheStatus .STORAGE2GPU , req_id , hash_keys , gpu_block_ids , timeout ))
1081+ self .cache_task_queue .put_transfer_task ((CacheStatus .STORAGE2GPU , task ))
10561082 if is_sync :
1057- storage_block_ids = self .wait_prefetch_storage_task (req_id )
1083+ storage_block_ids = self .wait_prefetch_storage_task (task . task_id )
10581084 return storage_block_ids
10591085
10601086 def wait_prefetch_storage_task (self , req_id ):
0 commit comments