Skip to content

Commit 0312217

Browse files
committed
add tool call output evaluator, fix several issues and inconsistencies, add eval helper tests
1 parent d4b53e5 commit 0312217

8 files changed

Lines changed: 1299 additions & 109 deletions

File tree

src/uipath/eval/_helpers/coded_evaluators_helpers.py

Lines changed: 130 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from opentelemetry.sdk.trace import ReadableSpan
88

9-
from uipath.eval.models.models import ToolCall
9+
from ..models import ToolCall, ToolOutput
1010

1111
COMPARATOR_MAPPINGS = {
1212
">": "gt",
@@ -63,14 +63,14 @@ def extract_tool_calls(spans: Sequence[ReadableSpan]) -> list[ToolCall]:
6363
else:
6464
arguments = {}
6565
tool_calls.append(ToolCall(name=str(tool_name), args=arguments))
66-
except json.JSONDecodeError:
67-
# Handle case where input.value is not valid JSON
66+
except (json.JSONDecodeError, SyntaxError, ValueError):
67+
# Handle case where input.value is not valid JSON/Python syntax
6868
tool_calls.append(ToolCall(name=str(tool_name), args={}))
6969

7070
return tool_calls
7171

7272

73-
def extract_tool_calls_outputs(spans: Sequence[ReadableSpan]) -> list[dict[str, Any]]:
73+
def extract_tool_calls_outputs(spans: Sequence[ReadableSpan]) -> list[ToolOutput]:
7474
"""Extract the outputs of the tool calls from execution spans.
7575
7676
Args:
@@ -83,7 +83,10 @@ def extract_tool_calls_outputs(spans: Sequence[ReadableSpan]) -> list[dict[str,
8383
for span in spans:
8484
if span.attributes and (tool_name := span.attributes.get("tool.name")):
8585
tool_calls_outputs.append(
86-
{"name": tool_name, "output": span.attributes.get("output.value", {})}
86+
ToolOutput(
87+
name=str(tool_name),
88+
output=span.attributes.get("output.value", ""),
89+
)
8790
)
8891
return tool_calls_outputs
8992

@@ -102,26 +105,30 @@ def tool_calls_order_score(
102105
Args:
103106
actual_tool_calls_names: List of tool names in the actual order
104107
expected_tool_calls_names: List of tool names in the expected order
105-
strict: If True, the function will return 0 if the actual calls do not match the expected calls
108+
strict: If True, the function will return 0 if the actual calls do not match the expected calls exactly
106109
107110
Returns:
108-
tuple[float, str]: Ratio of the LCS length to the number of expected, and the LCS string
111+
tuple[float, dict]: Ratio of the LCS length to the number of expected, and the justification dict
109112
"""
110113
justification = {
111-
"actual_tool_calls_order": actual_tool_calls_names,
112-
"expected_tool_calls_order": expected_tool_calls_names,
114+
"actual_tool_calls_order": list(actual_tool_calls_names),
115+
"expected_tool_calls_order": list(expected_tool_calls_names),
113116
"lcs": [],
114117
}
115118

119+
# Handle empty cases
120+
if not expected_tool_calls_names and not actual_tool_calls_names:
121+
return 1.0, justification
122+
elif not expected_tool_calls_names or not actual_tool_calls_names:
123+
return 0.0, justification
124+
125+
# Handle exact match
116126
if expected_tool_calls_names == actual_tool_calls_names:
117-
justification["lcs"] = actual_tool_calls_names
127+
justification["lcs"] = list(actual_tool_calls_names)
118128
return 1.0, justification
119-
elif (
120-
not expected_tool_calls_names
121-
or not actual_tool_calls_names
122-
or strict
123-
and actual_tool_calls_names != expected_tool_calls_names
124-
):
129+
130+
# Handle strict mode - only perfect matches allowed
131+
if strict:
125132
return 0.0, justification
126133

127134
# Calculate LCS with full DP table for efficient reconstruction
@@ -161,21 +168,28 @@ def tool_calls_count_score(
161168
strict: bool = False,
162169
justification_key: str = "explained_tool_calls_count",
163170
) -> tuple[float, dict[str, Any]]:
164-
"""Check if the expected tool calls are correctly called, where expected args must be a subset of actual args.
171+
"""Check if the expected tool call counts match the actual tool call counts.
165172
166173
Args:
167-
actual_tool_calls_count: List of actual tool calls count.
168-
expected_tool_calls_count: List of expected tool calls count.
174+
actual_tool_calls_count: Mapping of tool names to their actual call counts.
175+
expected_tool_calls_count: Mapping of tool names to expected (comparator, count) tuples.
169176
strict: If True, the function will return 0 if not all expected tool calls are matched.
177+
justification_key: Key to use for the justification in the returned dict.
170178
171179
Returns:
172-
tuple[float, str]: Score based on the number of matches, and the justification.
180+
tuple[float, dict]: Score based on the number of matches, and the justification dict.
173181
"""
174182
if not expected_tool_calls_count and not actual_tool_calls_count:
175-
return 1.0, {justification_key: "Both expected and actual tool calls are empty"}
183+
return 1.0, {
184+
justification_key: {
185+
"_result": "Both expected and actual tool calls are empty"
186+
}
187+
}
176188
elif not expected_tool_calls_count or not actual_tool_calls_count:
177189
return 0.0, {
178-
justification_key: "Either expected or actual tool calls are empty"
190+
justification_key: {
191+
"_result": "Either expected or actual tool calls are empty"
192+
}
179193
}
180194

181195
score = 0.0
@@ -192,7 +206,13 @@ def tool_calls_count_score(
192206
f"Actual: {actual_count}, Expected: {expected_count}, Score: {to_add}"
193207
)
194208
if strict and to_add == 0.0:
195-
return 0.0, justifications
209+
# When strict is True, if the actual count does not match the expected count, return 0
210+
# The justification should only include the breaching tool name
211+
return 0.0, {
212+
justification_key: {
213+
tool_name: justifications[justification_key][tool_name]
214+
}
215+
}
196216
score += to_add
197217
return score / len(expected_tool_calls_count), justifications
198218

@@ -204,32 +224,46 @@ def tool_calls_args_score(
204224
subset: bool = False,
205225
justification_key: str = "explained_tool_calls_args",
206226
) -> tuple[float, dict[str, Any]]:
207-
"""Check if the expected tool calls are correctly called, where expected args must be a subset of actual args.
227+
"""Check if the expected tool calls are correctly called with matching arguments.
208228
209229
This function does not check the order of the tool calls!
210230
211-
Arguments:
212-
actual_tool_calls (list[Dict[str, Any]]): List of actual tool calls in the format of {"name": str, "args": Dict[str, Any]}
213-
expected_tool_calls (list[Dict[str, Any]]): List of expected tool calls in the format of {"name": str, "args": Dict[str, Any]}
214-
strict (bool): If True, the function will return 0 if not all expected tool calls are matched
215-
subset (bool): If True, the function will check if the expected args are a subset of the actual args
231+
Args:
232+
actual_tool_calls: List of actual tool calls with their arguments.
233+
expected_tool_calls: List of expected tool calls with their arguments.
234+
strict: If True, the function will return 0 if not all expected tool calls are matched.
235+
subset: If True, the function will check if the expected args are a subset of the actual args.
236+
justification_key: Key to use for the justification in the returned dict.
216237
217238
Returns:
218-
tuple[float, str]: Score based on the number of matches, and the justification
239+
tuple[float, dict]: Score based on the number of matches, and the justification dict.
219240
"""
220241
if not expected_tool_calls and not actual_tool_calls:
221-
return 1.0, {justification_key: "Both expected and actual tool calls are empty"}
242+
return 1.0, {
243+
justification_key: {
244+
"_result": "Both expected and actual tool calls are empty"
245+
}
246+
}
222247
elif not expected_tool_calls or not actual_tool_calls:
223248
return 0.0, {
224-
justification_key: "Either expected or actual tool calls are empty"
249+
justification_key: {
250+
"_result": "Either expected or actual tool calls are empty"
251+
}
225252
}
226253

227254
cnt = 0
228255
visited: set[int] = set()
229256
justifications = {justification_key: {}}
257+
tool_counters: dict[str, int] = {}
258+
230259
for expected_tool_call in expected_tool_calls:
231260
for idx, call in enumerate(actual_tool_calls):
232261
if call.name == expected_tool_call.name and idx not in visited:
262+
# Get or initialize counter for this tool name
263+
tool_counters[call.name] = tool_counters.get(call.name, 0)
264+
tool_key = f"{call.name}_{tool_counters[call.name]}"
265+
tool_counters[call.name] += 1
266+
233267
# Check arguments based on mode
234268
if subset:
235269
# Subset mode: safely check if all expected args exist and match
@@ -249,13 +283,15 @@ def tool_calls_args_score(
249283
# Only possible in exact mode when key is missing
250284
args_match = False
251285

252-
justifications[justification_key][call.name] = (
286+
justifications[justification_key][tool_key] = (
253287
f"Actual: {call.args}, Expected: {expected_tool_call.args}, Score: {float(args_match)}"
254288
)
255289
if args_match:
256290
cnt += 1
257291
visited.add(idx)
258292
break
293+
# In case of mismatch, DON'T add to visited in non-strict mode
294+
# so this actual tool call can be matched against other expected calls
259295

260296
return (
261297
cnt / len(expected_tool_calls)
@@ -264,11 +300,12 @@ def tool_calls_args_score(
264300
), justifications
265301

266302

267-
def tool_output_score(
268-
actual_tool_calls_outputs: list[dict[str, Any]],
269-
expected_tool_calls_outputs: list[dict[str, Any]],
303+
def tool_calls_output_score(
304+
actual_tool_calls_outputs: list[ToolOutput],
305+
expected_tool_calls_outputs: list[ToolOutput],
270306
strict: bool = False,
271-
) -> float:
307+
justification_key: str = "explained_tool_calls_outputs",
308+
) -> tuple[float, dict[str, Any]]:
272309
"""Check if the expected tool calls are correctly called, where expected args must be a subset of actual args.
273310
274311
Args:
@@ -280,32 +317,71 @@ def tool_output_score(
280317
tuple[float, str]: Score based on the number of matches, and the justification.
281318
"""
282319
if not expected_tool_calls_outputs and not actual_tool_calls_outputs:
283-
return 1.0
284-
elif (
285-
not expected_tool_calls_outputs
286-
or not actual_tool_calls_outputs
287-
or strict
288-
and actual_tool_calls_outputs != expected_tool_calls_outputs
289-
):
290-
return 0.0
320+
return 1.0, {
321+
justification_key: {
322+
"_result": "Both expected and actual tool calls outputs are empty"
323+
}
324+
}
325+
elif not expected_tool_calls_outputs or not actual_tool_calls_outputs:
326+
return 0.0, {
327+
justification_key: {
328+
"_result": "Either expected or actual tool calls outputs are empty"
329+
}
330+
}
291331

292332
cnt = 0.0
333+
justifications = {justification_key: {}}
334+
visited: set[int] = set()
335+
tool_counters: dict[str, int] = {}
336+
293337
for expected_tool_call_output in expected_tool_calls_outputs:
294-
for actual_tool_call_output in actual_tool_calls_outputs:
295-
if actual_tool_call_output.get("name") == expected_tool_call_output.get(
296-
"name"
297-
):
298-
if json.loads(actual_tool_call_output.get("output", "{}")).get(
299-
"content"
300-
) == expected_tool_call_output.get("output"):
338+
matched = False
339+
340+
# Look through ALL actual tool calls to find a match
341+
for idx, actual_tool_call_output in enumerate(actual_tool_calls_outputs):
342+
if idx in visited:
343+
continue
344+
if actual_tool_call_output.name == expected_tool_call_output.name:
345+
# Get or initialize counter for this tool name
346+
tool_counters[actual_tool_call_output.name] = tool_counters.get(
347+
actual_tool_call_output.name, 0
348+
)
349+
tool_key = f"{actual_tool_call_output.name}_{tool_counters[actual_tool_call_output.name]}"
350+
tool_counters[actual_tool_call_output.name] += 1
351+
352+
justifications[justification_key][tool_key] = (
353+
f"Actual: {actual_tool_call_output.output}, Expected: {expected_tool_call_output.output}, Score: {float(actual_tool_call_output.output == expected_tool_call_output.output)}"
354+
)
355+
356+
if actual_tool_call_output.output == expected_tool_call_output.output:
357+
# Perfect match found
301358
cnt += 1.0
359+
visited.add(idx)
360+
matched = True
361+
break
302362
elif strict:
303-
return 0.0
363+
# In strict mode, any mismatch returns 0 immediately
364+
return 0.0, {
365+
justification_key: {
366+
tool_key: justifications[justification_key][tool_key]
367+
}
368+
}
369+
# In non-strict mode with mismatch, continue looking for perfect match
370+
# DON'T add to visited, DON'T break
371+
372+
# If no match found and we're in strict mode, return 0
373+
if not matched and strict:
374+
return 0.0, {
375+
justification_key: {
376+
"_result": f"No matching actual tool call found for expected {expected_tool_call_output.name}"
377+
}
378+
}
379+
304380
return (
305381
cnt / len(expected_tool_calls_outputs)
306382
if not strict
307383
else float(cnt == len(expected_tool_calls_outputs))
308-
)
384+
), justifications
309385

310386

311387
def trace_to_str(agent_trace: Sequence[ReadableSpan]) -> str:

0 commit comments

Comments
 (0)