diff --git a/pydantic_evals/pydantic_evals/evaluators/common.py b/pydantic_evals/pydantic_evals/evaluators/common.py
index 23aa7c03b..f1a83aa78 100644
--- a/pydantic_evals/pydantic_evals/evaluators/common.py
+++ b/pydantic_evals/pydantic_evals/evaluators/common.py
@@ -8,7 +8,7 @@
 
 from ..otel.span_tree import SpanQuery
 from .context import EvaluatorContext
-from .evaluator import EvaluationReason, Evaluator, EvaluatorOutput
+from .evaluator import EvaluationReason, EvaluationScalar, Evaluator, EvaluatorOutput
 
 __all__ = (
     'Equals',
@@ -164,6 +164,7 @@ class LLMJudge(Evaluator[object, object, object]):
     rubric: str
     model: models.Model | models.KnownModelName | None = None
     include_input: bool = False
+    return_score: bool = False
 
     async def evaluate(
         self,
@@ -177,7 +178,8 @@
         from .llm_as_a_judge import judge_output
 
         grading_output = await judge_output(ctx.output, self.rubric, self.model)
-        return EvaluationReason(value=grading_output.pass_, reason=grading_output.reason)
+        evaluation_value: EvaluationScalar = grading_output.score if self.return_score else grading_output.pass_
+        return EvaluationReason(value=evaluation_value, reason=grading_output.reason)
 
     def build_serialization_arguments(self):
         result = super().build_serialization_arguments()
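
A minimal usage sketch for reviewers (not part of the diff itself), showing the intended effect of the new `return_score` flag. It assumes the public `pydantic_evals` `Case`/`Dataset` API and a hypothetical `answer` task function. With the default `return_score=False`, the evaluator still emits the judge's boolean `pass_` verdict, so existing reports are unchanged; with `return_score=True`, it emits the judge's numeric `score` instead:

```python
from pydantic_evals import Case, Dataset
from pydantic_evals.evaluators import LLMJudge


async def answer(question: str) -> str:
    # Hypothetical task under evaluation; a real app would call a model here.
    return 'Paris'


dataset = Dataset(
    cases=[Case(name='capital', inputs='What is the capital of France?')],
    evaluators=[
        # Default behaviour: the evaluation value is the judge's boolean verdict.
        LLMJudge(rubric='Answer is factually correct.'),
        # New behaviour from this diff: the evaluation value is the numeric score.
        LLMJudge(rubric='Answer is factually correct.', return_score=True),
    ],
)

report = dataset.evaluate_sync(answer)
report.print()
```

Since `judge_output` already returns both `pass_` and `score` on its grading output, the flag only selects which of the two becomes the `EvaluationReason` value, keeping the default path backward compatible.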