-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_issue_5282.py
More file actions
241 lines (191 loc) · 7.73 KB
/
test_issue_5282.py
File metadata and controls
241 lines (191 loc) · 7.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Regression tests for google/adk-python#5282.
Runner._run_node_async (the dispatch path for Workflow / BaseNode roots)
dispatches on_user_message_callback, before_run_callback, and
on_event_callback, but does not dispatch run_after_run_callback
(runners.py:427 TODO).
Three tests:
(a) Baseline — pre-run and event hooks DO fire on a Workflow root.
(b) Regression anchor — after_run_callback does NOT fire (strict xfail).
Remove the xfail when the TODO at runners.py:427 is wired.
(c) Workaround proof — Runner subclass wrapping run_async restores
after_run_callback dispatch without touching ADK source.
Concurrency note for the WorkaroundRunner pattern: under concurrent
run_async calls on a shared Runner, the _last_ic stash should live on a
contextvars.ContextVar rather than self. The instance attribute is safe
for single-invocation tests but will race under concurrent load.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import AsyncGenerator
from typing import Optional
from google.adk.agents.invocation_context import InvocationContext
from google.adk.apps.app import App
from google.adk.events.event import Event
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.workflow import Workflow
from google.genai import types
import pytest
# Application name shared by the App, the Runner, and session creation below.
APP_NAME = "issue_5282_repro"
# User id used both when creating the session and when calling run_async.
USER_ID = "u1"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@dataclass
class CallbackCounts:
    """Per-hook tally of plugin lifecycle callback invocations.

    Every counter starts at zero. Field order is part of the positional
    constructor signature and must not be rearranged.
    """

    # Hooks dispatched before the run starts.
    on_user_message_callback: int = 0
    before_run_callback: int = 0
    # Hook dispatched for events produced during the run.
    on_event_callback: int = 0
    # Hook expected once the run has finished.
    after_run_callback: int = 0
class TracerPlugin(BasePlugin):
    """Plugin that tallies each lifecycle hook invocation in ``self.counts``.

    Every hook bumps its own counter on the shared ``CallbackCounts`` and
    returns ``None``.
    """

    # Opt this helper class out of pytest collection.
    __test__ = False

    def __init__(self) -> None:
        """Register the plugin under the name "tracer" with zeroed counters."""
        super().__init__(name="tracer")
        self.counts = CallbackCounts()

    async def on_user_message_callback(
        self,
        *,
        invocation_context: InvocationContext,
        user_message: types.Content,
    ) -> Optional[types.Content]:
        """Count one on_user_message_callback dispatch; return ``None``."""
        tally = self.counts
        tally.on_user_message_callback += 1
        return None

    async def before_run_callback(
        self,
        *,
        invocation_context: InvocationContext,
    ) -> Optional[types.Content]:
        """Count one before_run_callback dispatch; return ``None``."""
        tally = self.counts
        tally.before_run_callback += 1
        return None

    async def on_event_callback(
        self,
        *,
        invocation_context: InvocationContext,
        event: Event,
    ) -> Optional[Event]:
        """Count one on_event_callback dispatch; return ``None``."""
        tally = self.counts
        tally.on_event_callback += 1
        return None

    async def after_run_callback(
        self,
        *,
        invocation_context: InvocationContext,
    ) -> None:
        """Count one after_run_callback dispatch."""
        tally = self.counts
        tally.after_run_callback += 1
async def _terminal_node(ctx) -> Event:
    """Terminal workflow node: emit one model-role event with the text "done".

    The event deliberately carries content (not just state) so that
    _consume_event_queue runs the on_event_callback path, which is the
    canonical case the plugin hook targets.
    """
    reply = types.Content(parts=[types.Part(text="done")], role="model")
    return Event(content=reply)
def _build_runner(
    plugin: TracerPlugin, *, runner_cls: type[Runner] = Runner
) -> Runner:
    """Assemble a runner whose root is a one-node Workflow with *plugin* attached.

    Args:
        plugin: Plugin registered on the App.
        runner_cls: Runner class to instantiate; defaults to the stock Runner.

    Returns:
        A ``runner_cls`` instance using in-memory session and memory services.
    """
    root = Workflow(name="Issue5282Repro", edges=[("START", _terminal_node)])
    return runner_cls(
        app_name=APP_NAME,
        app=App(name=APP_NAME, root_agent=root, plugins=[plugin]),
        session_service=InMemorySessionService(),
        memory_service=InMemoryMemoryService(),
    )
async def _drive_one_invocation(runner: Runner) -> None:
    """Create a fresh session and drain a single run_async invocation.

    Args:
        runner: Runner whose session service creates the session and whose
            run_async stream is consumed to exhaustion.
    """
    created = await runner.session_service.create_session(
        app_name=APP_NAME, user_id=USER_ID
    )
    greeting = types.Content(parts=[types.Part(text="hi")], role="user")
    stream = runner.run_async(
        user_id=USER_ID,
        session_id=created.id,
        new_message=greeting,
    )
    async for _event in stream:
        # Events are discarded; the tests only inspect the plugin counters.
        pass
# ---------------------------------------------------------------------------
# Workaround: Runner subclass dispatching run_after_run_callback post-drain
# ---------------------------------------------------------------------------
class WorkaroundRunner(Runner):
    """Interim workaround for #5282.

    Wraps run_async to dispatch plugin_manager.run_after_run_callback once
    the inner generator drains. Captures the active InvocationContext via
    _new_invocation_context (called once at runners.py:446).

    The captured context is stashed on the instance, cleared at the start of
    every run_async call and consumed after dispatch. A run that never
    creates a context therefore cannot re-fire the callback for an earlier
    invocation, and the runner does not keep a finished context alive.
    Because the stash is per-instance, concurrent run_async calls on one
    runner are not supported by this pattern.

    Drop this class when the runners.py:427 TODO is resolved — the stock
    Runner will dispatch after_run_callback natively, and
    test_workflow_root_after_run_callback_not_dispatched will flip green
    as the signal.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments to Runner and start with an empty stash."""
        super().__init__(*args, **kwargs)
        self._last_ic: Optional[InvocationContext] = None

    def _new_invocation_context(self, session, **kwargs) -> InvocationContext:
        """Build the context via the base class and remember it for run_async."""
        ic = super()._new_invocation_context(session, **kwargs)
        self._last_ic = ic
        return ic

    async def run_async(self, **kwargs) -> AsyncGenerator[Event, None]:
        """Yield every event from Runner.run_async, then fire after_run_callback.

        The callback is dispatched only when the inner generator drains
        normally and a context was created during this call.
        """
        # Forget any context left over from a previous invocation so it can
        # never be mistaken for the one belonging to this run.
        self._last_ic = None
        async for event in super().run_async(**kwargs):
            yield event
        # Take ownership of the stash and clear it in one step.
        ic, self._last_ic = self._last_ic, None
        if ic is not None:
            await ic.plugin_manager.run_after_run_callback(
                invocation_context=ic
            )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_workflow_root_dispatches_pre_run_and_event_hooks():
    """Baseline: a Workflow (BaseNode) root still gets pre-run and event hooks."""
    tracer = TracerPlugin()
    await _drive_one_invocation(_build_runner(tracer))

    tally = tracer.counts
    assert tally.on_user_message_callback == 1
    assert tally.before_run_callback == 1
    assert tally.on_event_callback >= 1, (
        "on_event_callback should fire via _consume_event_queue "
        "(runners.py:619) for the content-bearing terminal event"
    )
@pytest.mark.xfail(
    strict=True,
    reason=(
        "#5282: runners.py:427 TODO — _run_node_async does not dispatch "
        "plugin_manager.run_after_run_callback on the BaseNode path. "
        "Remove this xfail when the TODO lands."
    ),
)
@pytest.mark.asyncio
async def test_workflow_root_after_run_callback_not_dispatched():
    """Regression anchor: the stock Runner does NOT fire after_run_callback.

    The assertion states the desired behavior, so under the strict xfail the
    test is reported as xfail while the bug exists and fails loudly as soon
    as after_run_callback starts firing. When the fix lands, delete the
    xfail decorator and the test becomes a green regression guard.
    """
    tracer = TracerPlugin()
    await _drive_one_invocation(_build_runner(tracer))

    assert tracer.counts.after_run_callback == 1
@pytest.mark.asyncio
async def test_workaround_runner_dispatches_after_run_callback():
    """WorkaroundRunner restores after_run_callback without patching ADK."""
    tracer = TracerPlugin()
    patched_runner = _build_runner(tracer, runner_cls=WorkaroundRunner)
    await _drive_one_invocation(patched_runner)

    tally = tracer.counts
    assert tally.on_user_message_callback == 1
    assert tally.before_run_callback == 1
    assert tally.on_event_callback >= 1
    assert tally.after_run_callback == 1