
Commit 40e1837

add justification type to all evaluators and tests for that
1 parent fec883e commit 40e1837

15 files changed: 626 additions & 47 deletions

src/uipath/_cli/_evals/_models/_output.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ class EvaluationResultDto(BaseModel):
     model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

     score: float
-    details: Optional[str] = None
+    details: Optional[str | BaseModel] = None
     evaluation_time: Optional[float] = None

     @model_serializer(mode="wrap")
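
With details widened to accept either a string or a Pydantic model, evaluators can attach structured justifications to a result instead of pre-formatted text. A minimal sketch of both forms; the import path simply mirrors the file above, and OrderJustification is a hypothetical stand-in for the structured justifications introduced later in this commit:

from pydantic import BaseModel

from uipath._cli._evals._models._output import EvaluationResultDto  # assumed module path, per the file above


class OrderJustification(BaseModel):  # hypothetical example model
    expected_tool_calls_order: list[str]
    actual_tool_calls_order: list[str]
    lcs: list[str]


# Plain-string details (previous behaviour, still supported)
text_result = EvaluationResultDto(score=1.0, details="exact match")

# Structured details (enabled by this change)
model_result = EvaluationResultDto(
    score=0.67,
    details=OrderJustification(
        expected_tool_calls_order=["search", "fetch", "summarize"],
        actual_tool_calls_order=["search", "summarize"],
        lcs=["search", "summarize"],
    ),
)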

src/uipath/eval/_helpers/coded_evaluators_helpers.py

Lines changed: 35 additions & 21 deletions
@@ -92,7 +92,7 @@ def tool_calls_order_score(
     actual_tool_calls_names: Sequence[str],
     expected_tool_calls_names: Sequence[str],
     strict: bool = False,
-) -> tuple[float, str]:
+) -> tuple[float, dict[str, Any]]:
     """The function calculates a score based on LCS applied to the order of the tool calls.

     It calculates the longest common subsequence between the actual tool calls
@@ -107,18 +107,22 @@ def tool_calls_order_score(
     Returns:
         tuple[float, str]: Ratio of the LCS length to the number of expected, and the LCS string
     """
-    justification_template = f"Expected tool calls: {expected_tool_calls_names}\nActual tool calls: {actual_tool_calls_names}"
-    if not strict:
-        justification_template += "\nLongest common subsequence: {lcs}"
+    justification = {
+        "actual_tool_calls_order": actual_tool_calls_names,
+        "expected_tool_calls_order": expected_tool_calls_names,
+        "lcs": [],
+    }
+
     if expected_tool_calls_names == actual_tool_calls_names:
-        return 1.0, justification_template.format(lcs=actual_tool_calls_names)
+        justification["lcs"] = actual_tool_calls_names
+        return 1.0, justification
     elif (
         not expected_tool_calls_names
         or not actual_tool_calls_names
         or strict
         and actual_tool_calls_names != expected_tool_calls_names
     ):
-        return 0.0, justification_template.format(lcs="")
+        return 0.0, justification

     # Calculate LCS with full DP table for efficient reconstruction
     m, n = len(actual_tool_calls_names), len(expected_tool_calls_names)
@@ -147,14 +151,16 @@ def tool_calls_order_score(

     lcs.reverse()  # Reverse to get correct order
     lcs_length = len(lcs)
-    return lcs_length / n, justification_template.format(lcs=" ".join(lcs))
+    justification["lcs"] = lcs
+    return lcs_length / n, justification


 def tool_calls_count_score(
     actual_tool_calls_count: Mapping[str, int],
     expected_tool_calls_count: Mapping[str, tuple[str, int]],
     strict: bool = False,
-) -> tuple[float, str]:
+    justification_key: str = "explained_tool_calls_count",
+) -> tuple[float, dict[str, Any]]:
     """Check if the expected tool calls are correctly called, where expected args must be a subset of actual args.

     Args:
@@ -166,34 +172,38 @@ def tool_calls_count_score(
         tuple[float, str]: Score based on the number of matches, and the justification.
     """
     if not expected_tool_calls_count and not actual_tool_calls_count:
-        return 1.0, "Both expected and actual tool calls are empty"
+        return 1.0, {justification_key: "Both expected and actual tool calls are empty"}
     elif not expected_tool_calls_count or not actual_tool_calls_count:
-        return 0.0, "Either expected or actual tool calls are empty"
+        return 0.0, {
+            justification_key: "Either expected or actual tool calls are empty"
+        }

     score = 0.0
-    justifications = []
+    justifications = {justification_key: {}}
     for tool_name, (
         expected_comparator,
         expected_count,
     ) in expected_tool_calls_count.items():
         actual_count = actual_tool_calls_count.get(tool_name, 0.0)
         comparator = f"__{COMPARATOR_MAPPINGS[expected_comparator]}__"
         to_add = float(getattr(actual_count, comparator)(expected_count))
-        justifications.append(
-            f"{tool_name}: Actual count: {actual_count}, Expected count: {expected_count}, Score: {to_add}"
+
+        justifications[justification_key][tool_name] = (
+            f"Actual: {actual_count}, Expected: {expected_count}, Score: {to_add}"
         )
         if strict and to_add == 0.0:
-            return 0.0, justifications[-1]
+            return 0.0, justifications
         score += to_add
-    return score / len(expected_tool_calls_count), "\n".join(justifications)
+    return score / len(expected_tool_calls_count), justifications


 def tool_calls_args_score(
     actual_tool_calls: list[ToolCall],
     expected_tool_calls: list[ToolCall],
     strict: bool = False,
     subset: bool = False,
-) -> tuple[float, str]:
+    justification_key: str = "explained_tool_calls_args",
+) -> tuple[float, dict[str, Any]]:
     """Check if the expected tool calls are correctly called, where expected args must be a subset of actual args.

     This function does not check the order of the tool calls!
@@ -208,13 +218,15 @@ def tool_calls_args_score(
         tuple[float, str]: Score based on the number of matches, and the justification
     """
     if not expected_tool_calls and not actual_tool_calls:
-        return 1.0, "Both expected and actual tool calls are empty"
+        return 1.0, {justification_key: "Both expected and actual tool calls are empty"}
     elif not expected_tool_calls or not actual_tool_calls:
-        return 0.0, "Either expected or actual tool calls are empty"
+        return 0.0, {
+            justification_key: "Either expected or actual tool calls are empty"
+        }

     cnt = 0
     visited: set[int] = set()
-    justifications = []
+    justifications = {justification_key: {}}
     for expected_tool_call in expected_tool_calls:
         for idx, call in enumerate(actual_tool_calls):
             if call.name == expected_tool_call.name and idx not in visited:
@@ -237,7 +249,9 @@ def tool_calls_args_score(
                     # Only possible in exact mode when key is missing
                     args_match = False

-                justifications.append(f"{call.name}: Args match: {args_match}")
+                justifications[justification_key][call.name] = (
+                    f"Actual: {call.args}, Expected: {expected_tool_call.args}, Score: {float(args_match)}"
+                )
                 if args_match:
                     cnt += 1
                     visited.add(idx)
@@ -247,7 +261,7 @@ def tool_calls_args_score(
         cnt / len(expected_tool_calls)
        if not strict
        else float(cnt == len(expected_tool_calls))
-    ), "\n".join(justifications)
+    ), justifications


 def tool_output_score(
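
The helper functions now return machine-readable dicts in place of pre-formatted justification strings. A rough usage sketch of tool_calls_order_score under the logic above; the import path mirrors the file name, and the values in the comments are what that logic should produce for these inputs, not captured output:

from uipath.eval._helpers.coded_evaluators_helpers import tool_calls_order_score  # assumed module path

score, justification = tool_calls_order_score(
    actual_tool_calls_names=["search", "summarize"],
    expected_tool_calls_names=["search", "fetch", "summarize"],
)

# The LCS is ["search", "summarize"], so score == 2 / 3 and justification is a dict:
# {
#     "actual_tool_calls_order": ["search", "summarize"],
#     "expected_tool_calls_order": ["search", "fetch", "summarize"],
#     "lcs": ["search", "summarize"],
# }

tool_calls_count_score and tool_calls_args_score follow the same pattern, namespacing their per-tool explanations under the new justification_key parameter ("explained_tool_calls_count" and "explained_tool_calls_args" by default).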

src/uipath/eval/coded_evaluators/base_evaluator.py

Lines changed: 101 additions & 2 deletions
@@ -5,7 +5,7 @@
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from typing import Any, Generic, TypeVar, get_args
+from typing import Any, Generic, TypeVar, Union, cast, get_args

 from pydantic import BaseModel, ConfigDict, Field, model_validator

@@ -47,11 +47,18 @@ class BaseEvaluatorConfig(BaseModel):
     default_evaluation_criteria: BaseEvaluationCriteria | None = None


+class BaseEvaluatorJustification(BaseModel):
+    """Base class for all evaluator justifications."""
+
+    pass
+
+
 T = TypeVar("T", bound=BaseEvaluationCriteria)
 C = TypeVar("C", bound=BaseEvaluatorConfig)
+J = TypeVar("J", bound=Union[str, None, BaseEvaluatorJustification])


-class BaseEvaluator(BaseModel, Generic[T, C], ABC):
+class BaseEvaluator(BaseModel, Generic[T, C, J], ABC):
     """Abstract base class for all evaluators."""

     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -61,6 +68,9 @@ class BaseEvaluator(BaseModel, Generic[T, C], ABC):
     evaluation_criteria_type: type[T] = Field(
         description="The type used for evaluation criteria validation and creation"
     )
+    justification_type: type[J] = Field(
+        description="The type used for justification validation and creation"
+    )
     evaluator_config: C = Field(
         exclude=True, description="The validated config object instance"
     )
@@ -101,6 +111,10 @@ def validate_model(cls, values: Any) -> Any:
         config_type = cls._extract_config_type()
         values["config_type"] = config_type

+        # Always extract and set justification_type
+        justification_type = cls._extract_justification_type()
+        values["justification_type"] = justification_type
+
         # Validate and create the config object if config dict is provided
         if config_dict := values.get("config"):
             try:
@@ -182,6 +196,33 @@ def _extract_config_type(cls) -> type[BaseEvaluatorConfig]:
                 f"Ensure the class properly inherits from BaseEvaluator with correct Generic parameters."
             )

+    @classmethod
+    def _extract_justification_type(cls) -> type[J]:
+        """Extract the justification type from Pydantic model fields.
+
+        Returns:
+            The justification type (str, None, or BaseEvaluatorJustification subclass)
+        """
+        if cls.__name__ == "BaseEvaluator":
+            return cast(type[J], type(None))
+
+        if hasattr(cls, "model_fields") and "justification_type" in cls.model_fields:
+            field_info = cls.model_fields["justification_type"]
+            if hasattr(field_info, "annotation"):
+                annotation = field_info.annotation
+                if args := get_args(annotation):
+                    justification_type = args[0]
+                    # Support str, None, or BaseEvaluatorJustification subclasses
+                    if justification_type is str or justification_type is type(None):
+                        return cast(type[J], justification_type)
+                    elif isinstance(justification_type, type) and issubclass(
+                        justification_type, BaseEvaluatorJustification
+                    ):
+                        return cast(type[J], justification_type)
+
+        # Default to None if we can't determine the type
+        return cast(type[J], type(None))
+
     def validate_evaluation_criteria(self, criteria: Any) -> T:
         """Validate and convert input to the correct evaluation criteria type.

@@ -213,6 +254,64 @@ def validate_evaluation_criteria(self, criteria: Any) -> T:
                 f"Cannot convert {type(criteria)} to {self.evaluation_criteria_type}: {e}"
             ) from e

+    def validate_justification(self, justification: Any) -> J:
+        """Validate and convert input to the correct justification type.
+
+        Args:
+            justification: The justification to validate (str, None, dict, BaseEvaluatorJustification, or other)
+
+        Returns:
+            The validated justification of the correct type
+        """
+        # The key insight: J is constrained to be one of str, None, or BaseEvaluatorJustification
+        # At instantiation time, J gets bound to exactly one of these types
+        # We need to handle each case and ensure the return matches the bound type
+
+        # Handle None type - when J is bound to None (the literal None type)
+        if self.justification_type is type(None):
+            # When J is None, we can only return None
+            return cast(J, justification if justification is None else None)
+
+        # Handle str type - when J is bound to str
+        if self.justification_type is str:
+            # When J is str, we must return a str
+            if justification is None:
+                return cast(J, "")
+            return cast(J, str(justification))
+
+        # Handle BaseEvaluatorJustification subclasses - when J is bound to a specific subclass
+        if isinstance(self.justification_type, type) and issubclass(
+            self.justification_type, BaseEvaluatorJustification
+        ):
+            # When J is a BaseEvaluatorJustification subclass, we must return that type
+            if justification is None:
+                raise ValueError(
+                    f"None is not allowed for justification type {self.justification_type}"
+                )
+
+            if isinstance(justification, self.justification_type):
+                return cast(J, justification)
+            elif isinstance(justification, dict):
+                return cast(J, self.justification_type.model_validate(justification))
+            elif hasattr(justification, "__dict__"):
+                return cast(
+                    J, self.justification_type.model_validate(justification.__dict__)
+                )
+            else:
+                try:
+                    return cast(
+                        J, self.justification_type.model_validate(justification)
+                    )
+                except Exception as e:
+                    raise ValueError(
+                        f"Cannot convert {type(justification)} to {self.justification_type}: {e}"
+                    ) from e
+
+        # Fallback: try to return as-is or raise error
+        raise ValueError(
+            f"Unsupported justification type {self.justification_type} for input {type(justification)}"
+        )
+
     @classmethod
     def get_evaluation_criteria_schema(cls) -> dict[str, Any]:
         """Get the JSON schema for the evaluation criteria type.

src/uipath/eval/coded_evaluators/contains_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ class ContainsEvaluatorConfig(BaseEvaluatorConfig):


 class ContainsEvaluator(
-    BaseEvaluator[ContainsEvaluationCriteria, ContainsEvaluatorConfig]
+    BaseEvaluator[ContainsEvaluationCriteria, ContainsEvaluatorConfig, None]
 ):
     """Evaluator that checks if the actual output contains the expected output.

src/uipath/eval/coded_evaluators/exact_match_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ class ExactMatchEvaluatorConfig(OutputEvaluatorConfig):
     negated: bool = False


-class ExactMatchEvaluator(OutputEvaluator[ExactMatchEvaluatorConfig]):
+class ExactMatchEvaluator(OutputEvaluator[ExactMatchEvaluatorConfig, type(None)]):
     """Evaluator that performs exact structural matching between expected and actual outputs.

     This evaluator returns True if the actual output exactly matches the expected output
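
ContainsEvaluator and ExactMatchEvaluator keep no structured justification, so they bind the new parameter to None (written as None and type(None) respectively above). For that binding, the first branch of validate_justification in the base class simply drops whatever it is given. A small sketch of that behaviour, not library code:

def _validate_none_justification(justification):
    # Mirrors the type(None) branch above: any non-None value is dropped.
    return justification if justification is None else None


assert _validate_none_justification(None) is None
assert _validate_none_justification("ignored text") is None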

src/uipath/eval/coded_evaluators/json_similarity_evaluator.py

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@ class JsonSimilarityEvaluatorConfig(OutputEvaluatorConfig):
     target_output_key: str = Field(default="*", frozen=True, exclude=True)


-class JsonSimilarityEvaluator(OutputEvaluator[JsonSimilarityEvaluatorConfig]):
+class JsonSimilarityEvaluator(OutputEvaluator[JsonSimilarityEvaluatorConfig, str]):
     """Deterministic evaluator that scores structural JSON similarity between expected and actual output.

     Compares expected versus actual JSON-like structures and returns a
@@ -51,9 +51,10 @@ async def evaluate(
             self._get_expected_output(evaluation_criteria),
             self._get_actual_output(agent_execution),
         )
+        validated_justification = self.validate_justification(justification)
         return NumericEvaluationResult(
             score=score,
-            details=justification,
+            details=validated_justification,
         )

     def _compare_json(self, expected: Any, actual: Any) -> tuple[float, str]:
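
JsonSimilarityEvaluator binds the justification parameter to str, so the str branch of validate_justification runs before the justification is stored in details: None is normalised to an empty string and anything else goes through str(). A reduced sketch of that branch, not library code:

def _validate_str_justification(justification):
    if justification is None:
        return ""  # None becomes an empty string rather than a missing value
    return str(justification)


assert _validate_str_justification(None) == ""
assert _validate_str_justification("7 of 9 leaf values match") == "7 of 9 leaf values match"

LLMJudgeMixin (below) makes the same choice, binding J to str so that the judge's free-text justification is validated the same way before it is attached to the result.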

src/uipath/eval/coded_evaluators/llm_as_judge_evaluator.py

Lines changed: 5 additions & 2 deletions
@@ -36,7 +36,7 @@ class BaseLLMJudgeEvaluatorConfig(BaseEvaluatorConfig):
 C = TypeVar("C", bound=BaseLLMJudgeEvaluatorConfig)


-class LLMJudgeMixin(BaseEvaluator[T, C]):
+class LLMJudgeMixin(BaseEvaluator[T, C, str]):
     """Mixin that provides common LLM judge functionality."""

     system_prompt: str = LLMJudgePromptTemplates.LLM_JUDGE_SYSTEM_PROMPT
@@ -94,10 +94,13 @@ async def evaluate(
         )

         llm_response = await self._get_llm_response(evaluation_prompt)
+        validated_justification = self.validate_justification(
+            llm_response.justification
+        )

         return NumericEvaluationResult(
             score=round(llm_response.score / 100.0, 2),
-            details=llm_response.justification,
+            details=validated_justification,
         )

     def _create_evaluation_prompt(

src/uipath/eval/coded_evaluators/llm_judge_output_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class LLMJudgeStrictJSONSimilarityOutputEvaluatorConfig(LLMJudgeOutputEvaluatorC


 class BaseLLMOutputEvaluator(
-    OutputEvaluator[OC],
+    OutputEvaluator[OC, str],
     LLMJudgeMixin[OutputEvaluationCriteria, OC],
 ):
     """Base class for LLM judge output evaluators that contains all shared functionality.
