Skip to content

Commit f1e6450

Browse files
authored
Batch Concurrency (#358)
* Restored concurrency_modifier support
* Concurrent calls now use batched job takes instead of multiple parallel singular job takes
* Using asyncio.Queue to track jobs queue and processing
* JobScaler refactored to properly coordinate job takes and processing
1 parent ca2bc00 commit f1e6450

16 files changed

Lines changed: 444 additions & 348 deletions

File tree

.github/workflows/CI-pytests.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,4 +32,4 @@ jobs:
3232
pip install '.[test]'
3333
3434
- name: Run Tests
35-
run: pytest --cov-config=.coveragerc --timeout=120 --timeout_method=thread --cov=runpod --cov-report=xml --cov-report=term-missing --cov-fail-under=98 -W error -p no:cacheprovider -p no:unraisableexception
35+
run: pytest

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ test_logging.py
88
/example_project
99
IGNORE.py
1010
/quick-test
11+
.DS_Store
1112

1213
# Byte-compiled / optimized / DLL files
1314
__pycache__/

examples/serverless/concurrent_handler.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,18 @@ async def async_generator_handler(job):
1111
return job
1212

1313

14+
# --------------------------- Concurrency Modifier --------------------------- #
15+
def concurrency_modifier(current_concurrency=1):
16+
"""
17+
Concurrency modifier.
18+
"""
19+
desired_concurrency = current_concurrency
20+
21+
# Do some logic to determine the desired concurrency.
22+
23+
return desired_concurrency
24+
25+
1426
runpod.serverless.start(
15-
{
16-
"handler": async_generator_handler,
17-
}
27+
{"handler": async_generator_handler, "concurrency_modifier": concurrency_modifier}
1828
)

runpod/serverless/modules/rp_fastapi.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
""" Used to launch the FastAPI web server when worker is running in API mode. """
22

3-
# pylint: disable=too-few-public-methods, line-too-long
4-
53
import os
64
import threading
75
import uuid

runpod/serverless/modules/rp_http.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from runpod.http_client import ClientSession
1212
from runpod.serverless.modules.rp_logger import RunPodLogger
1313

14-
from .worker_state import WORKER_ID, Jobs
14+
from .worker_state import WORKER_ID
1515

1616
JOB_DONE_URL_TEMPLATE = str(
1717
os.environ.get("RUNPOD_WEBHOOK_POST_OUTPUT", "JOB_DONE_URL")
@@ -24,7 +24,6 @@
2424
JOB_STREAM_URL = JOB_STREAM_URL_TEMPLATE.replace("$RUNPOD_POD_ID", WORKER_ID)
2525

2626
log = RunPodLogger()
27-
job_list = Jobs()
2827

2928

3029
async def _transmit(client_session: ClientSession, url, job_data):
@@ -49,7 +48,6 @@ async def _transmit(client_session: ClientSession, url, job_data):
4948
await client_response.text()
5049

5150

52-
# pylint: disable=too-many-arguments, disable=line-too-long
5351
async def _handle_result(
5452
session: ClientSession, job_data, job, url_template, log_message, is_stream=False
5553
):
@@ -79,7 +77,6 @@ async def _handle_result(
7977
url_template == JOB_DONE_URL
8078
and job_data.get("status", None) != "IN_PROGRESS"
8179
):
82-
job_list.remove_job(job["id"])
8380
log.info("Finished.", job["id"])
8481

8582

runpod/serverless/modules/rp_job.py

Lines changed: 61 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -2,29 +2,27 @@
22
Job related helpers.
33
"""
44

5-
# pylint: disable=too-many-branches
6-
75
import asyncio
86
import inspect
97
import json
108
import os
119
import traceback
12-
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union
10+
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List
1311

1412
from runpod.http_client import ClientSession
1513
from runpod.serverless.modules.rp_logger import RunPodLogger
1614

1715
from ...version import __version__ as runpod_version
1816
from .rp_tips import check_return_size
19-
from .worker_state import WORKER_ID, Jobs
17+
from .worker_state import WORKER_ID, JobsQueue
2018

2119
JOB_GET_URL = str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace("$ID", WORKER_ID)
2220

2321
log = RunPodLogger()
24-
job_list = Jobs()
22+
job_list = JobsQueue()
2523

2624

27-
def _job_get_url():
25+
def _job_get_url(batch_size: int = 1):
2826
"""
2927
Prepare the URL for making a 'get' request to the serverless API (sls).
3028
@@ -34,89 +32,68 @@ def _job_get_url():
3432
Returns:
3533
str: The prepared URL for the 'get' request to the serverless API.
3634
"""
37-
job_in_progress = "1" if job_list.get_job_list() else "0"
38-
return JOB_GET_URL + f"&job_in_progress={job_in_progress}"
35+
job_in_progress = "1" if job_list.get_job_count() else "0"
36+
37+
if batch_size > 1:
38+
job_take_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
39+
job_take_url += f"&batch_size={batch_size}&batch_strategy=LMove"
40+
else:
41+
job_take_url = JOB_GET_URL
3942

43+
return job_take_url + f"&job_in_progress={job_in_progress}"
4044

41-
async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]]:
45+
46+
async def get_job(
47+
session: ClientSession, num_jobs: int = 1
48+
) -> Optional[List[Dict[str, Any]]]:
4249
"""
43-
Get the job from the queue.
44-
Will continue trying to get a job until one is available.
50+
Get a job from the job-take API.
51+
52+
`num_jobs = 1` will query the legacy singular job-take API.
53+
54+
`num_jobs > 1` will query the batch job-take API.
4555
4656
Args:
47-
session (ClientSession): The async http client to use for the request.
48-
retry (bool): Whether to retry if no job is available.
57+
session (ClientSession): The aiohttp ClientSession to use for the request.
58+
num_jobs (int): The number of jobs to get.
4959
"""
50-
next_job = None
51-
52-
while next_job is None:
53-
try:
54-
async with session.get(_job_get_url()) as response:
55-
if response.status == 204:
56-
log.debug("No content, no job to process.")
57-
if retry is False:
58-
break
59-
continue
60-
61-
if response.status == 400:
62-
log.debug(
63-
"Received 400 status, expected when FlashBoot is enabled."
64-
)
65-
if retry is False:
66-
break
67-
continue
68-
69-
if response.status != 200:
70-
log.error(f"Failed to get job, status code: {response.status}")
71-
if retry is False:
72-
break
73-
continue
74-
75-
received_request = await response.json()
76-
log.debug(f"Request Received | {received_request}")
77-
78-
# Check if the job is valid
79-
job_id = received_request.get("id", None)
80-
job_input = received_request.get("input", None)
81-
82-
if None in [job_id, job_input]:
83-
missing_fields = []
84-
if job_id is None:
85-
missing_fields.append("id")
86-
if job_input is None:
87-
missing_fields.append("input")
88-
89-
log.error(f"Job has missing field(s): {', '.join(missing_fields)}.")
90-
else:
91-
next_job = received_request
92-
93-
except asyncio.TimeoutError:
94-
log.debug("Timeout error, retrying.")
95-
if retry is False:
96-
break
97-
98-
except Exception as err: # pylint: disable=broad-except
99-
err_type = type(err).__name__
100-
err_message = str(err)
101-
err_traceback = traceback.format_exc()
102-
log.error(
103-
f"Failed to get job. | Error Type: {err_type} | Error Message: {err_message}"
104-
)
105-
log.error(f"Traceback: {err_traceback}")
106-
107-
if next_job is None:
108-
log.debug("No job available, waiting for the next one.")
109-
if retry is False:
110-
break
111-
112-
await asyncio.sleep(1)
113-
else:
114-
job_list.add_job(next_job["id"])
115-
log.debug("Request ID added.", next_job["id"])
116-
117-
return next_job
60+
try:
61+
async with session.get(_job_get_url(num_jobs)) as response:
62+
if response.status == 204:
63+
log.debug("No content, no job to process.")
64+
return
65+
66+
if response.status == 400:
67+
log.debug("Received 400 status, expected when FlashBoot is enabled.")
68+
return
69+
70+
if response.status != 200:
71+
log.error(f"Failed to get job, status code: {response.status}")
72+
return
73+
74+
jobs = await response.json()
75+
log.debug(f"Request Received | {jobs}")
76+
77+
# legacy job-take API
78+
if isinstance(jobs, dict):
79+
if "id" not in jobs or "input" not in jobs:
80+
raise Exception("Job has missing field(s): id or input.")
81+
return [jobs]
82+
83+
# batch job-take API
84+
if isinstance(jobs, list):
85+
return jobs
86+
87+
except asyncio.TimeoutError:
88+
log.debug("Timeout error, retrying.")
89+
90+
except Exception as error:
91+
log.error(
92+
f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
93+
)
11894

119-
return None
95+
# empty
96+
return []
12097

12198

12299
async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
@@ -164,7 +141,7 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
164141

165142
check_return_size(run_result) # Checks the size of the return body.
166143

167-
except Exception as err: # pylint: disable=broad-except
144+
except Exception as err:
168145
error_info = {
169146
"error_type": str(type(err)),
170147
"error_message": str(err),
@@ -209,7 +186,7 @@ async def run_job_generator(
209186
log.debug(f"Generator output: {output_partial}", job["id"])
210187
yield {"output": output_partial}
211188

212-
except Exception as err: # pylint: disable=broad-except
189+
except Exception as err:
213190
log.error(err, job["id"])
214191
yield {"error": f"handler: {str(err)} \ntraceback: {traceback.format_exc()}"}
215192
finally:

runpod/serverless/modules/rp_logger.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from typing import Optional
1616

1717
MAX_MESSAGE_LENGTH = 4096
18-
LOG_LEVELS = ["NOTSET", "DEBUG", "TRACE", "INFO", "WARN", "ERROR"]
18+
LOG_LEVELS = ["NOTSET", "TRACE", "DEBUG", "INFO", "WARN", "ERROR"]
1919

2020

2121
def _validate_log_level(log_level):

runpod/serverless/modules/rp_ping.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,11 @@
1212

1313
from runpod.http_client import SyncClientSession
1414
from runpod.serverless.modules.rp_logger import RunPodLogger
15-
from runpod.serverless.modules.worker_state import WORKER_ID, Jobs
15+
from runpod.serverless.modules.worker_state import WORKER_ID, JobsQueue
1616
from runpod.version import __version__ as runpod_version
1717

1818
log = RunPodLogger()
19-
jobs = Jobs() # Contains the list of jobs that are currently running.
19+
jobs = JobsQueue() # Contains the list of jobs that are currently running.
2020

2121

2222
class Heartbeat:
@@ -96,7 +96,7 @@ def _send_ping(self):
9696
)
9797

9898
log.debug(
99-
f"Heartbeat Sent | URL: {self.PING_URL} | Status: {result.status_code}"
99+
f"Heartbeat Sent | URL: {result.url} | Status: {result.status_code}"
100100
)
101101

102102
except requests.RequestException as err:

0 commit comments

Comments (0)