Skip to content

Commit 7977621

Browse files
committed
remove sleep calls from tests, add cancellation details to async cancellation errors from external activities
1 parent ab65714 commit 7977621

3 files changed

Lines changed: 135 additions & 18 deletions

File tree

temporalio/client.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import temporalio.runtime
5757
import temporalio.service
5858
import temporalio.workflow
59+
from temporalio.activity import ActivityCancellationDetails
5960
from temporalio.service import (
6061
HttpConnectProxyConfig,
6162
KeepAliveConfig,
@@ -5126,9 +5127,12 @@ def __init__(self) -> None:
51265127
class AsyncActivityCancelledError(temporalio.exceptions.TemporalError):
51275128
"""Error that occurs when async activity attempted heartbeat but was cancelled."""
51285129

5129-
def __init__(self) -> None:
5130+
details: Optional[ActivityCancellationDetails] = None
5131+
5132+
def __init__(self, details: Optional[ActivityCancellationDetails] = None) -> None:
51305133
"""Create async activity cancelled error."""
51315134
super().__init__("Activity cancelled")
5135+
self.details = details
51325136

51335137

51345138
class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError):
@@ -6265,7 +6269,12 @@ async def heartbeat_async_activity(
62656269
timeout=input.rpc_timeout,
62666270
)
62676271
if resp_by_id.cancel_requested or resp_by_id.activity_paused:
6268-
raise AsyncActivityCancelledError()
6272+
raise AsyncActivityCancelledError(
6273+
details=ActivityCancellationDetails(
6274+
cancel_requested=resp_by_id.cancel_requested,
6275+
paused=resp_by_id.activity_paused,
6276+
)
6277+
)
62696278

62706279
else:
62716280
resp = await self._client.workflow_service.record_activity_task_heartbeat(
@@ -6280,7 +6289,12 @@ async def heartbeat_async_activity(
62806289
timeout=input.rpc_timeout,
62816290
)
62826291
if resp.cancel_requested or resp.activity_paused:
6283-
raise AsyncActivityCancelledError()
6292+
raise AsyncActivityCancelledError(
6293+
details=ActivityCancellationDetails(
6294+
cancel_requested=resp.cancel_requested,
6295+
paused=resp.activity_paused,
6296+
)
6297+
)
62846298

62856299
async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
62866300
result = (

tests/helpers/__init__.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import closing
66
from datetime import timedelta
7-
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar, cast
7+
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
1010
from temporalio.api.enums.v1 import IndexedValueType
@@ -224,17 +224,49 @@ async def assert_pending_activity_exists_eventually(
224224
) -> PendingActivityInfo:
225225
"""Wait until a pending activity with the given ID exists and return it."""
226226

227-
async def check() -> Optional[PendingActivityInfo]:
228-
desc = await handle.describe()
229-
for act in desc.raw_description.pending_activities:
230-
if act.activity_id == activity_id:
231-
return act
227+
async def check() -> PendingActivityInfo:
228+
act_info = await _get_pending_activity_info(handle, activity_id)
229+
if act_info is not None:
230+
return act_info
232231
raise AssertionError(
233232
f"Activity with ID {activity_id} not found in pending activities"
234233
)
235234

236-
activity_info = await assert_eventually(check, timeout=timeout)
237-
return cast(PendingActivityInfo, activity_info)
235+
return await assert_eventually(check, timeout=timeout)
236+
237+
238+
async def wait_for_next_heartbeat_cycle(
239+
handle: WorkflowHandle,
240+
activity_id: str,
241+
initial_heartbeat_time: Any,
242+
timeout: timedelta = timedelta(seconds=5),
243+
) -> None:
244+
"""Wait for the next heartbeat cycle by monitoring last_heartbeat_time changes."""
245+
246+
async def check_heartbeat_changed() -> None:
247+
current_info = await _get_pending_activity_info(handle, activity_id)
248+
if current_info is None:
249+
raise AssertionError(
250+
f"Activity with ID {activity_id} not found in pending activities"
251+
)
252+
if current_info.last_heartbeat_time == initial_heartbeat_time:
253+
raise AssertionError(
254+
f"Activity with ID {activity_id} has not heartbeated yet"
255+
)
256+
257+
await assert_eventually(check_heartbeat_changed, timeout=timeout)
258+
259+
260+
async def _get_pending_activity_info(
261+
handle: WorkflowHandle,
262+
activity_id: str,
263+
) -> Optional[PendingActivityInfo]:
264+
"""Get pending activity info by ID, or None if not found."""
265+
desc = await handle.describe()
266+
for act in desc.raw_description.pending_activities:
267+
if act.activity_id == activity_id:
268+
return act
269+
return None
238270

239271

240272
async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):

tests/worker/test_workflow.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@
4949
from temporalio.api.sdk.v1 import EnhancedStackTrace
5050
from temporalio.api.workflowservice.v1 import (
5151
GetWorkflowExecutionHistoryRequest,
52-
PauseActivityRequest,
5352
ResetStickyTaskQueueRequest,
5453
)
5554
from temporalio.bridge.proto.workflow_activation import WorkflowActivation
5655
from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion
5756
from temporalio.client import (
57+
AsyncActivityCancelledError,
5858
Client,
5959
RPCError,
6060
RPCStatusCode,
@@ -127,6 +127,7 @@
127127
new_worker,
128128
pause_and_assert,
129129
unpause_and_assert,
130+
wait_for_next_heartbeat_cycle,
130131
workflow_update_exists,
131132
)
132133
from tests.helpers.external_stack_trace import (
@@ -7637,14 +7638,16 @@ async def heartbeat_activity(
76377638
while True:
76387639
try:
76397640
activity.heartbeat()
7640-
# If we are on the second attempt, we have retried due to pause/unpause.
7641-
if activity.info().attempt > 1:
7641+
# If we have heartbeat details, we are on the second attempt, we have retried due to pause/unpause.
7642+
if activity.info().heartbeat_details:
76427643
return activity.cancellation_details()
76437644
await asyncio.sleep(0.1)
76447645
except (CancelledError, asyncio.CancelledError) as err:
76457646
if not catch_err:
76467647
raise err
76477648
return activity.cancellation_details()
7649+
finally:
7650+
activity.heartbeat("finally-complete")
76487651

76497652

76507653
@activity.defn
@@ -7654,14 +7657,16 @@ def sync_heartbeat_activity(
76547657
while True:
76557658
try:
76567659
activity.heartbeat()
7657-
# If we are on the second attempt, we have retried due to pause/unpause.
7658-
if activity.info().attempt > 1:
7660+
# If we have heartbeat details, we are on the second attempt, we have retried due to pause/unpause.
7661+
if activity.info().heartbeat_details:
76597662
return activity.cancellation_details()
76607663
time.sleep(0.1)
76617664
except (CancelledError, asyncio.CancelledError) as err:
76627665
if not catch_err:
76637666
raise err
76647667
return activity.cancellation_details()
7668+
finally:
7669+
activity.heartbeat("finally-complete")
76657670

76667671

76677672
@workflow.defn
@@ -7806,7 +7811,10 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment):
78067811

78077812
# Wait for next heartbeat to propagate the cancellation. Unpausing before the heartbeat
78087813
# will show activity as unpaused to core. Consequently, it will *not* issue an activity cancel.
7809-
time.sleep(0.3)
7814+
await wait_for_next_heartbeat_cycle(
7815+
handle, activity_info_1.activity_id, activity_info_1.last_heartbeat_time
7816+
)
7817+
78107818
# Unpause activity
78117819
await unpause_and_assert(client, handle, activity_info_1.activity_id)
78127820
# Expect second activity to have started now
@@ -7818,11 +7826,74 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment):
78187826
# Pause activity then assert it is paused
78197827
await pause_and_assert(client, handle, activity_info_2.activity_id)
78207828
# Wait for next heartbeat to propagate the cancellation.
7821-
time.sleep(0.3)
7829+
await wait_for_next_heartbeat_cycle(
7830+
handle, activity_info_2.activity_id, activity_info_2.last_heartbeat_time
7831+
)
78227832
# Unpause activity
78237833
await unpause_and_assert(client, handle, activity_info_2.activity_id)
78247834

78257835
# Check workflow complete
78267836
result = await handle.result()
78277837
assert result[0] == None
78287838
assert result[1] == None
7839+
7840+
7841+
@activity.defn
7842+
async def external_activity_heartbeat() -> None:
7843+
activity.raise_complete_async()
7844+
7845+
7846+
@workflow.defn
7847+
class ExternalActivityWorkflow:
7848+
@workflow.run
7849+
async def run(self, activity_id: str) -> None:
7850+
await workflow.execute_activity(
7851+
external_activity_heartbeat,
7852+
activity_id=activity_id,
7853+
start_to_close_timeout=timedelta(seconds=10),
7854+
heartbeat_timeout=timedelta(seconds=1),
7855+
retry_policy=RetryPolicy(maximum_attempts=2),
7856+
)
7857+
7858+
7859+
async def test_external_activity_cancellation_details(
7860+
client: Client, env: WorkflowEnvironment
7861+
):
7862+
if env.supports_time_skipping:
7863+
pytest.skip("Time-skipping server does not support pause API yet")
7864+
async with Worker(
7865+
client,
7866+
task_queue=str(uuid.uuid4()),
7867+
workflows=[ExternalActivityWorkflow],
7868+
activities=[external_activity_heartbeat],
7869+
) as worker:
7870+
test_activity_id = f"heartbeat-activity-{uuid.uuid4()}"
7871+
7872+
wf_handle = await client.start_workflow(
7873+
ExternalActivityWorkflow.run,
7874+
test_activity_id,
7875+
id=f"test-external-activity-pause-{uuid.uuid4()}",
7876+
task_queue=worker.task_queue,
7877+
)
7878+
wf_desc = await wf_handle.describe()
7879+
7880+
# Wait for external activity
7881+
activity_info = await assert_pending_activity_exists_eventually(
7882+
wf_handle, test_activity_id
7883+
)
7884+
# Assert not paused
7885+
assert not activity_info.paused
7886+
7887+
external_activity_handle = client.get_async_activity_handle(
7888+
workflow_id=wf_desc.id, run_id=wf_desc.run_id, activity_id=test_activity_id
7889+
)
7890+
7891+
# Pause activity then assert it is paused
7892+
await pause_and_assert(client, wf_handle, activity_info.activity_id)
7893+
7894+
try:
7895+
await external_activity_handle.heartbeat()
7896+
except AsyncActivityCancelledError as err:
7897+
assert err.details == temporalio.activity.ActivityCancellationDetails(
7898+
paused=True
7899+
)

0 commit comments

Comments
 (0)