Job related helpers.
"""

5- # pylint: disable=too-many-branches
6-
75import asyncio
86import inspect
97import json
108import os
119import traceback
12- from typing import Any , AsyncGenerator , Callable , Dict , Optional , Union
10+ from typing import Any , AsyncGenerator , Callable , Dict , Optional , Union , List
1311
1412from runpod .http_client import ClientSession
1513from runpod .serverless .modules .rp_logger import RunPodLogger
1614
1715from ...version import __version__ as runpod_version
1816from .rp_tips import check_return_size
19- from .worker_state import WORKER_ID , Jobs
17+ from .worker_state import WORKER_ID , JobsQueue
2018
# URL for pulling jobs from the serverless API; the "$ID" placeholder in the
# RUNPOD_WEBHOOK_GET_JOB environment value is replaced with this worker's ID.
JOB_GET_URL = str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace("$ID", WORKER_ID)

log = RunPodLogger()

# Queue tracking the jobs currently held by this worker.
job_list = JobsQueue()
2523
2624
def _job_get_url(batch_size: int = 1):
    """
    Prepare the URL for making a 'get' request to the serverless API (sls).

    A batch_size greater than 1 targets the batch job-take endpoint and
    appends the batch query parameters; otherwise the legacy singular
    endpoint is used.

    Returns:
        str: The prepared URL for the 'get' request to the serverless API.
    """
    # Tell the API whether this worker is already busy with jobs.
    in_progress_flag = "1" if job_list.get_job_count() else "0"

    if batch_size <= 1:
        base_url = JOB_GET_URL
    else:
        base_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
        base_url += f"&batch_size={batch_size}&batch_strategy=LMove"

    return f"{base_url}&job_in_progress={in_progress_flag}"
4044
async def get_job(
    session: ClientSession, num_jobs: int = 1
) -> Optional[List[Dict[str, Any]]]:
    """
    Get job(s) from the job-take API.

    `num_jobs = 1` will query the legacy singular job-take API.

    `num_jobs > 1` will query the batch job-take API.

    Args:
        session (ClientSession): The aiohttp ClientSession to use for the request.
        num_jobs (int): The number of jobs to get.

    Returns:
        A list of job dicts on success, None when the API responded with a
        non-200 status (204 / 400 / other), or an empty list when the payload
        was unusable or an error was handled.
    """
    try:
        async with session.get(_job_get_url(num_jobs)) as response:
            if response.status == 204:
                log.debug("No content, no job to process.")
                return

            if response.status == 400:
                log.debug("Received 400 status, expected when FlashBoot is enabled.")
                return

            if response.status != 200:
                log.error(f"Failed to get job, status code: {response.status}")
                return

            jobs = await response.json()
            log.debug(f"Request Received | {jobs}")

            # legacy job-take API returns a single job object
            if isinstance(jobs, dict):
                if "id" not in jobs or "input" not in jobs:
                    raise Exception("Job has missing field(s): id or input.")
                return [jobs]

            # batch job-take API returns a list of job objects
            if isinstance(jobs, list):
                return jobs

    except asyncio.TimeoutError:
        log.debug("Timeout error, retrying.")

    except Exception as error:  # pylint: disable=broad-except
        log.error(
            f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
        )
        # Log the traceback too — without it, unexpected failures inside the
        # request/parse path are undiagnosable from the error message alone.
        log.error(f"Traceback: {traceback.format_exc()}")

    # Unrecognized payload shape or a handled error: report no jobs.
    return []
12097
12198
12299async def run_job (handler : Callable , job : Dict [str , Any ]) -> Dict [str , Any ]:
@@ -164,7 +141,7 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
164141
165142 check_return_size (run_result ) # Checks the size of the return body.
166143
167- except Exception as err : # pylint: disable=broad-except
144+ except Exception as err :
168145 error_info = {
169146 "error_type" : str (type (err )),
170147 "error_message" : str (err ),
@@ -209,7 +186,7 @@ async def run_job_generator(
209186 log .debug (f"Generator output: { output_partial } " , job ["id" ])
210187 yield {"output" : output_partial }
211188
212- except Exception as err : # pylint: disable=broad-except
189+ except Exception as err :
213190 log .error (err , job ["id" ])
214191 yield {"error" : f"handler: { str (err )} \n traceback: { traceback .format_exc ()} " }
215192 finally :
0 commit comments