Skip to content

Commit 40e75f6

Browse files
Expose Nexus Endpoint in a Nexus Operation Handler (#1437)
Expose Nexus Endpoint in a Nexus Operation Handler
1 parent 0916177 commit 40e75f6

6 files changed

Lines changed: 48 additions & 13 deletions

File tree

temporalio/bridge/proto/nexus/nexus_pb2.py

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

temporalio/bridge/proto/nexus/nexus_pb2.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ class NexusTask(google.protobuf.message.Message):
240240
TASK_FIELD_NUMBER: builtins.int
241241
CANCEL_TASK_FIELD_NUMBER: builtins.int
242242
REQUEST_DEADLINE_FIELD_NUMBER: builtins.int
243+
ENDPOINT_FIELD_NUMBER: builtins.int
243244
@property
244245
def task(
245246
self,
@@ -265,13 +266,18 @@ class NexusTask(google.protobuf.message.Message):
265266
Only set when variant is `task` and the header was present with a valid value.
266267
Represented as an absolute timestamp.
267268
"""
269+
endpoint: builtins.str
270+
"""The endpoint this request was addressed to. Extracted from the request for convenient access.
271+
Only set when variant is `task`.
272+
"""
268273
def __init__(
269274
self,
270275
*,
271276
task: temporalio.api.workflowservice.v1.request_response_pb2.PollNexusTaskQueueResponse
272277
| None = ...,
273278
cancel_task: global___CancelNexusTask | None = ...,
274279
request_deadline: google.protobuf.timestamp_pb2.Timestamp | None = ...,
280+
endpoint: builtins.str = ...,
275281
) -> None: ...
276282
def HasField(
277283
self,
@@ -291,6 +297,8 @@ class NexusTask(google.protobuf.message.Message):
291297
field_name: typing_extensions.Literal[
292298
"cancel_task",
293299
b"cancel_task",
300+
"endpoint",
301+
b"endpoint",
294302
"request_deadline",
295303
b"request_deadline",
296304
"task",

temporalio/bridge/sdk-core

temporalio/nexus/_operation_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ class Info:
7979
Retrieved inside a Nexus operation handler via :py:func:`info`.
8080
"""
8181

82+
endpoint: str
83+
"""The endpoint this Nexus request was addressed to."""
84+
8285
namespace: str
8386
"""The namespace of the worker handling this Nexus operation."""
8487

temporalio/worker/_nexus.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ async def raise_from_exception_queue() -> NoReturn:
140140
headers=dict(task.request.header),
141141
task_cancellation=task_cancellation,
142142
request_deadline=request_deadline,
143+
endpoint=nexus_task.endpoint,
143144
)
144145
)
145146
self._running_tasks[task.task_token] = _RunningNexusTask(
@@ -154,6 +155,7 @@ async def raise_from_exception_queue() -> NoReturn:
154155
headers=dict(task.request.header),
155156
task_cancellation=task_cancellation,
156157
request_deadline=request_deadline,
158+
endpoint=nexus_task.endpoint,
157159
)
158160
)
159161
self._running_tasks[task.task_token] = _RunningNexusTask(
@@ -224,6 +226,7 @@ async def _handle_cancel_operation_task(
224226
headers: Mapping[str, str],
225227
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
226228
request_deadline: datetime | None,
229+
endpoint: str,
227230
) -> None:
228231
"""Handle a cancel operation task.
229232
@@ -244,7 +247,11 @@ async def _handle_cancel_operation_task(
244247
request_deadline=request_deadline,
245248
)
246249
temporalio.nexus._operation_context._TemporalCancelOperationContext(
247-
info=lambda: Info(namespace=self._namespace, task_queue=self._task_queue),
250+
info=lambda: Info(
251+
endpoint=endpoint,
252+
namespace=self._namespace,
253+
task_queue=self._task_queue,
254+
),
248255
nexus_context=ctx,
249256
client=self._client,
250257
_runtime_metric_meter=self._metric_meter,
@@ -293,6 +300,7 @@ async def _handle_start_operation_task(
293300
headers: Mapping[str, str],
294301
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
295302
request_deadline: datetime | None,
303+
endpoint: str,
296304
) -> None:
297305
"""Handle a start operation task.
298306
@@ -302,7 +310,11 @@ async def _handle_start_operation_task(
302310
try:
303311
try:
304312
start_response = await self._start_operation(
305-
start_request, headers, task_cancellation, request_deadline
313+
start_request,
314+
headers,
315+
task_cancellation,
316+
request_deadline,
317+
endpoint,
306318
)
307319
except asyncio.CancelledError:
308320
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
@@ -346,6 +358,7 @@ async def _start_operation(
346358
headers: Mapping[str, str],
347359
cancellation: nexusrpc.handler.OperationTaskCancellation,
348360
request_deadline: datetime | None,
361+
endpoint: str,
349362
) -> temporalio.api.nexus.v1.StartOperationResponse:
350363
"""Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.
351364
@@ -375,7 +388,11 @@ async def _start_operation(
375388
temporalio.nexus._operation_context._TemporalStartOperationContext(
376389
nexus_context=ctx,
377390
client=self._client,
378-
info=lambda: Info(namespace=self._namespace, task_queue=self._task_queue),
391+
info=lambda: Info(
392+
endpoint=endpoint,
393+
namespace=self._namespace,
394+
task_queue=self._task_queue,
395+
),
379396
_runtime_metric_meter=self._metric_meter,
380397
_worker_shutdown_event=self._worker_shutdown_event,
381398
).set()

tests/nexus/test_workflow_caller.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,11 @@ async def get_info(
703703
self, _ctx: StartOperationContext, _input: None
704704
) -> dict[str, str]:
705705
info = nexus.info()
706-
return {"namespace": info.namespace, "task_queue": info.task_queue}
706+
return {
707+
"endpoint": info.endpoint,
708+
"namespace": info.namespace,
709+
"task_queue": info.task_queue,
710+
}
707711

708712

709713
@workflow.defn
@@ -733,6 +737,9 @@ async def test_nexus_info_includes_namespace(client: Client, env: WorkflowEnviro
733737
id=str(uuid.uuid4()),
734738
task_queue=task_queue,
735739
)
740+
if not env.supports_time_skipping:
741+
# Time-skipping server doesn't send the endpoint yet.
742+
assert result["endpoint"] == endpoint_name
736743
assert result["namespace"] == client.namespace
737744
assert result["task_queue"] == task_queue
738745

0 commit comments

Comments
 (0)