50 changes: 44 additions & 6 deletions apps/query-eval/queryeval/driver.py
@@ -330,6 +330,22 @@ def do_eval(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalResult:
            )
            metrics.plan_diff_count = len(plan_diff)

        # Evaluate doc retrieval
        if not query.expected_docs:
            console.print("[yellow]:construction: No expected document list found, skipping..")
        elif not result.retrieved_docs:
            console.print("[yellow]:construction: No computed document list found, skipping..")
        else:
            if query.expected_docs.issubset(result.retrieved_docs):
                console.print("[green]✔ Documents retrieved match")
            else:
                console.print("[red]:x: Document retrieval mismatch")
                console.print(f"Missing docs: {query.expected_docs - result.retrieved_docs}")
            metrics.doc_retrieval_recall = len(result.retrieved_docs & query.expected_docs) / len(query.expected_docs)
            metrics.doc_retrieval_precision = len(result.retrieved_docs & query.expected_docs) / len(
                result.retrieved_docs
            )

        # Evaluate result
        if not result.result:
            console.print("[yellow] No query execution result available, skipping..", style="italic")
@@ -368,19 +384,41 @@ def query_all(self):

    def print_metrics_summary(self):
        """Summarize metrics."""
        # Plan metrics
        console.rule("Evaluation summary")

        # Plan metrics
        plan_correct = sum(1 for result in self.results_map.values() if result.metrics.plan_similarity == 1.0)
        console.print(f"Plans correct: {plan_correct}/{len(self.results_map)}")
        average_plan_correctness = sum(result.metrics.plan_similarity for result in self.results_map.values()) / len(
            self.results_map
        )
        average_plan_correctness = sum(
            result.metrics.plan_similarity for result in self.results_map.values() if result.metrics.plan_similarity
        ) / len(self.results_map)
        console.print(f"Avg. plan correctness: {average_plan_correctness}")
        console.print(
            "Avg. plan diff count: "
            f"{sum(result.metrics.plan_diff_count for result in self.results_map.values()) / len(self.results_map)}"
            + str(
                sum(
                    result.metrics.plan_diff_count
                    for result in self.results_map.values()
                    if result.metrics.plan_diff_count
                )
                / len(self.results_map)
            )
        )

        # Evaluate doc retrieval
        correct_retrievals = sum(
            1 for result in self.results_map.values() if result.metrics.doc_retrieval_recall == 1.0
        )
        expected_retrievals = sum(1 for result in self.results_map.values() if result.query.expected_docs)
        average_precision = (
            sum(
                result.metrics.doc_retrieval_precision
                for result in self.results_map.values()
                if result.metrics.doc_retrieval_precision
            )
            / expected_retrievals
        )
        console.print(f"Successful doc retrievals: {correct_retrievals}/{expected_retrievals}")
        console.print(f"Average precision: {average_precision}")
        # TODO: Query execution metrics
        console.print("Query result correctness: not implemented")

74 changes: 74 additions & 0 deletions apps/query-eval/queryeval/test-results.yaml
@@ -0,0 +1,74 @@
config:
  config_file: /private/var/folders/rd/5lfl86jj2yb3gn2zwwgqjgrc0000gn/T/pytest-of-vinayakthapliyal/pytest-108/test_driver_do_query0/test_input.yaml
  doc_limit: null
  dry_run: false
  index: test-index
  llm: null
  llm_cache_path: null
  log_file: null
  natural_language_response: true
  overwrite: false
  query_cache_path: null
  results_file: test-results.yaml
  tags: null
data_schema:
  fields:
    field1:
      description: text
      examples: null
      field_type: <class 'str'>
    field2:
      description: keyword
      examples: null
      field_type: <class 'str'>
  examples: null
results:
- error: null
  metrics:
    correctness_score: null
    doc_retrieval_precision: null
    doc_retrieval_recall: null
    plan_diff_count: null
    plan_generation_time: null
    plan_similarity: null
    query_time: null
    similarity_score: null
  notes: null
  plan: null
  query:
    expected: null
    expected_docs:
    - doc1.pdf
    - doc2.pdf
    expected_plan: null
    notes: null
    plan: null
    query: test query 1
    tags:
    - test
  result: null
  retrieved_docs: null
  timestamp: '2024-11-05T02:46:16.861230+00:00'
- error: null
  metrics:
    correctness_score: null
    doc_retrieval_precision: null
    doc_retrieval_recall: null
    plan_diff_count: null
    plan_generation_time: null
    plan_similarity: null
    query_time: null
    similarity_score: null
  notes: null
  plan: null
  query:
    expected: null
    expected_docs: null
    expected_plan: null
    notes: null
    plan: null
    query: test query 2
    tags: null
  result: null
  retrieved_docs: null
  timestamp: '2024-11-05T02:46:16.861896+00:00'
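For inspection, a results file like this loads with plain PyYAML; the driver presumably has its own typed loader, so this is just a standalone sanity check (the path is this file's repo location):

import yaml

with open("apps/query-eval/queryeval/test-results.yaml") as f:
    data = yaml.safe_load(f)

assert data["results"][0]["query"]["expected_docs"] == ["doc1.pdf", "doc2.pdf"]
assert data["results"][1]["query"]["expected_docs"] is None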
30 changes: 25 additions & 5 deletions apps/query-eval/queryeval/test_driver.py
@@ -1,6 +1,8 @@
import pytest
from unittest.mock import MagicMock, patch

from sycamore.query.schema import OpenSearchSchema, OpenSearchSchemaField

from queryeval.driver import QueryEvalDriver
from queryeval.types import (
    QueryEvalQuery,
@@ -14,7 +16,13 @@
@pytest.fixture
def mock_client():
    client = MagicMock()
    client.get_opensearch_schema.return_value = {"field1": "text", "field2": "keyword"}
    schema = OpenSearchSchema(
        fields={
            "field1": OpenSearchSchemaField(field_type="<class 'str'>", description="text"),
            "field2": OpenSearchSchemaField(field_type="<class 'str'>", description="keyword"),
        }
    )
    client.get_opensearch_schema.return_value = schema
    return client


@@ -29,11 +37,13 @@ def test_input_file(tmp_path):
    input_file.write_text(
        """
config:

  index: test-index
  results_file: test-results.yaml
queries:
  - query: "test query 1"
    tags: ["test"]
    expected_docs: ["doc1.pdf", "doc2.pdf"]
  - query: "test query 2"
"""
    )
@@ -46,8 +56,10 @@ def test_driver_init(test_input_file, mock_client):

    assert driver.config.config.index == "test-index"
    assert driver.config.config.doc_limit == 10
    assert driver.config.config.tags == ["test"]
    assert len(driver.config.queries) == 2
    assert driver.config.queries[0].tags == ["test"]
    assert driver.config.queries[1].tags is None
    assert driver.config.queries[0].expected_docs == {"doc1.pdf", "doc2.pdf"}
    assert driver.config.queries[1].expected_docs is None
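The set equality asserted here works because expected_docs is declared Optional[Set[str]] on QueryEvalQuery (see the types.py change below), so pydantic coerces the YAML list into a set on load. A standalone sketch of that coercion, using a hypothetical mini-model:

from typing import Optional, Set
from pydantic import BaseModel

class Q(BaseModel):  # hypothetical stand-in for QueryEvalQuery
    expected_docs: Optional[Set[str]] = None

assert Q(expected_docs=["doc1.pdf", "doc2.pdf"]).expected_docs == {"doc1.pdf", "doc2.pdf"}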


def test_driver_do_plan(test_input_file, mock_client, mock_plan):
@@ -75,15 +87,23 @@ def test_driver_do_query(test_input_file, mock_client, mock_plan):
        mock_client.run_plan.return_value = mock_query_result

        result = driver.do_query(query, result)
        driver.eval_all()  # we are only asserting that this runs
        assert result.result == "test result"


def test_driver_do_eval(test_input_file, mock_client, mock_plan):
    with patch("queryeval.driver.SycamoreQueryClient", return_value=mock_client):
        driver = QueryEvalDriver(input_file_path=test_input_file)
        expected_docs = {"doc1.pdf", "doc2.pdf"}
        query = QueryEvalQuery(query="test query", expected_plan=mock_plan, expected_docs=expected_docs)
        result = QueryEvalResult(query=query, plan=mock_plan, retrieved_docs=expected_docs)

        result = driver.do_eval(query, result)
        assert result.metrics.plan_similarity == 1.0
        assert result.metrics.doc_retrieval_recall == 1.0

        query = QueryEvalQuery(query="test query", expected_plan=mock_plan, expected_docs=expected_docs)
        result = QueryEvalResult(query=query, plan=mock_plan)
        result = QueryEvalResult(query=query, plan=mock_plan, retrieved_docs={"doc1.pdf"})

        result = driver.do_eval(query, result)
        assert result.metrics.plan_similarity == 1.0
        assert result.metrics.doc_retrieval_recall == 0.5
10 changes: 10 additions & 0 deletions apps/query-eval/queryeval/types.py
@@ -33,6 +33,7 @@ class QueryEvalQuery(BaseModel):
    query: str
    expected: Optional[Union[str, List[Dict[str, Any]]]] = None
    expected_plan: Optional[LogicalPlan] = None
    expected_docs: Optional[Set[str]] = None
    plan: Optional[LogicalPlan] = None
    tags: Optional[List[str]] = None
    notes: Optional[str] = None
@@ -50,10 +51,19 @@ class QueryEvalInputFile(BaseModel):
class QueryEvalMetrics(BaseModel):
"""Represents metrics associated with a result."""

    # Plan metrics
    plan_generation_time: Optional[float] = None
    plan_similarity: Optional[float] = None
    plan_diff_count: Optional[int] = None

    # Document retrieval metrics
    doc_retrieval_recall: Optional[float] = None
    doc_retrieval_precision: Optional[float] = None

    # Performance metrics
    query_time: Optional[float] = None

    # String answer metrics
    correctness_score: Optional[float] = None
    similarity_score: Optional[float] = None
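Because every metric here defaults to None, the aggregations in print_metrics_summary filter on truthiness; that skips unset metrics, but it would also skip a legitimate 0.0 score. A minimal sketch of the distinction, with a hypothetical stand-in model:

from typing import Optional
from pydantic import BaseModel

class Metrics(BaseModel):  # hypothetical stand-in for QueryEvalMetrics
    doc_retrieval_precision: Optional[float] = None

values = [Metrics(doc_retrieval_precision=0.8), Metrics(), Metrics(doc_retrieval_precision=0.0)]
# Truthiness filter (as in the diff) drops both None and 0.0:
assert [m.doc_retrieval_precision for m in values if m.doc_retrieval_precision] == [0.8]
# An explicit None check would keep the 0.0 score:
assert [m.doc_retrieval_precision for m in values if m.doc_retrieval_precision is not None] == [0.8, 0.0]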
