PR comments
baitsguy committed Nov 5, 2024
commit bdad2a3fd8a57b792f2ac73e3e1337955c764be6
20 changes: 16 additions & 4 deletions apps/query-eval/queryeval/driver.py
@@ -342,6 +342,9 @@ def do_eval(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalResult:
            console.print("[red]:x: Document retrieval mismatch")
            console.print(f"Missing docs: {query.expected_docs - result.retrieved_docs}")
            metrics.doc_retrieval_recall = len(result.retrieved_docs & query.expected_docs) / len(query.expected_docs)
            metrics.doc_retrieval_precision = len(result.retrieved_docs & query.expected_docs) / len(
                result.retrieved_docs
            )

        # Evaluate result
        if not result.result:
@@ -386,20 +389,29 @@ def print_metrics_summary(self):
        # Plan metrics
        plan_correct = sum(1 for result in self.results_map.values() if result.metrics.plan_similarity == 1.0)
        console.print(f"Plans correct: {plan_correct}/{len(self.results_map)}")
        average_plan_correctness = sum(result.metrics.plan_similarity for result in self.results_map.values()) / len(
            self.results_map
        )
        average_plan_correctness = sum(
            result.metrics.plan_similarity for result in self.results_map.values() if result.metrics.plan_similarity
        ) / len(self.results_map)
        console.print(f"Avg. plan correctness: {average_plan_correctness}")
        console.print(
            "Avg. plan diff count: "
            f"{sum(result.metrics.plan_diff_count for result in self.results_map.values() if result.metrics.plan_diff_count) / len(self.results_map)}"
        )
        # Evaluate doc retrieval
        correct_retrievals = sum(
            1 for result in self.results_map.values() if result.metrics.doc_retrieval_recall == 1.0
        )
        expected_retrievals = sum(1 for result in self.results_map.values() if result.query.expected_docs)
        average_precision = (
            sum(
                result.metrics.doc_retrieval_precision
                for result in self.results_map.values()
                if result.metrics.doc_retrieval_precision
            )
            / expected_retrievals
        )
        console.print(f"Successful doc retrievals: {correct_retrievals}/{expected_retrievals}")
        console.print(f"Average precision: {average_precision}")
        # TODO: Query execution metrics
        console.print("Query result correctness: not implemented")

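Note: a minimal standalone sketch of the arithmetic behind the new metrics above, with made-up document names and values (nothing below is part of the PR):

# Hypothetical illustration of the recall/precision set arithmetic in do_eval.
expected_docs = {"a.pdf", "b.pdf", "c.pdf"}   # docs the query should retrieve
retrieved_docs = {"a.pdf", "b.pdf", "d.pdf"}  # docs the query actually retrieved
hits = retrieved_docs & expected_docs         # {"a.pdf", "b.pdf"}
recall = len(hits) / len(expected_docs)       # 2/3: expected docs that were found
precision = len(hits) / len(retrieved_docs)   # 2/3: retrieved docs that were expected

# print_metrics_summary then averages the non-None per-query precisions over
# the number of queries that declared expected docs (values made up):
per_query_precision = [0.5, None, 1.0]
expected_retrievals = 2
average_precision = sum(p for p in per_query_precision if p) / expected_retrievals  # 0.75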
74 changes: 74 additions & 0 deletions apps/query-eval/queryeval/test-results.yaml
@@ -0,0 +1,74 @@
config:
  config_file: /private/var/folders/rd/5lfl86jj2yb3gn2zwwgqjgrc0000gn/T/pytest-of-vinayakthapliyal/pytest-108/test_driver_do_query0/test_input.yaml
  doc_limit: null
  dry_run: false
  index: test-index
  llm: null
  llm_cache_path: null
  log_file: null
  natural_language_response: true
  overwrite: false
  query_cache_path: null
  results_file: test-results.yaml
  tags: null
data_schema:
  fields:
    field1:
      description: text
      examples: null
      field_type: <class 'str'>
    field2:
      description: keyword
      examples: null
      field_type: <class 'str'>
  examples: null
results:
- error: null
  metrics:
    correctness_score: null
    doc_retrieval_precision: null
    doc_retrieval_recall: null
    plan_diff_count: null
    plan_generation_time: null
    plan_similarity: null
    query_time: null
    similarity_score: null
  notes: null
  plan: null
  query:
    expected: null
    expected_docs:
    - doc1.pdf
    - doc2.pdf
    expected_plan: null
    notes: null
    plan: null
    query: test query 1
    tags:
    - test
  result: null
  retrieved_docs: null
  timestamp: '2024-11-05T02:46:16.861230+00:00'
- error: null
  metrics:
    correctness_score: null
    doc_retrieval_precision: null
    doc_retrieval_recall: null
    plan_diff_count: null
    plan_generation_time: null
    plan_similarity: null
    query_time: null
    similarity_score: null
  notes: null
  plan: null
  query:
    expected: null
    expected_docs: null
    expected_plan: null
    notes: null
    plan: null
    query: test query 2
    tags: null
  result: null
  retrieved_docs: null
  timestamp: '2024-11-05T02:46:16.861896+00:00'
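For reference, a quick sketch of loading this results file; the relative path and the PyYAML dependency are assumptions, not part of the PR:

import yaml  # assumes PyYAML is installed

with open("apps/query-eval/queryeval/test-results.yaml") as f:
    data = yaml.safe_load(f)
print(data["results"][0]["query"]["expected_docs"])  # ['doc1.pdf', 'doc2.pdf']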
12 changes: 11 additions & 1 deletion apps/query-eval/queryeval/test_driver.py
@@ -1,6 +1,8 @@
import pytest
from unittest.mock import MagicMock, patch

from sycamore.query.schema import OpenSearchSchema, OpenSearchSchemaField

from queryeval.driver import QueryEvalDriver
from queryeval.types import (
    QueryEvalQuery,
@@ -14,7 +16,13 @@
@pytest.fixture
def mock_client():
    client = MagicMock()
    client.get_opensearch_schema.return_value = {"field1": "text", "field2": "keyword"}
    schema = OpenSearchSchema(
        fields={
            "field1": OpenSearchSchemaField(field_type="<class 'str'>", description="text"),
            "field2": OpenSearchSchemaField(field_type="<class 'str'>", description="keyword"),
        }
    )
    client.get_opensearch_schema.return_value = schema
    return client


@@ -29,6 +37,7 @@ def test_input_file(tmp_path):
    input_file.write_text(
        """
config:

  index: test-index
  results_file: test-results.yaml
queries:
@@ -78,6 +87,7 @@ def test_driver_do_query(test_input_file, mock_client, mock_plan):
    mock_client.run_plan.return_value = mock_query_result

    result = driver.do_query(query, result)
    driver.eval_all()  # we are only asserting that this runs
    assert result.result == "test result"
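As an aside, a small sketch of what the typed schema in the fixture above gives the test over the raw dict it replaces; the schema is constructed exactly as in the fixture, and the attribute access assumes the usual Pydantic-style model:

# Sketch only: attribute access on the typed schema vs. key lookups on a dict.
from sycamore.query.schema import OpenSearchSchema, OpenSearchSchemaField

schema = OpenSearchSchema(
    fields={"field1": OpenSearchSchemaField(field_type="<class 'str'>", description="text")}
)
assert schema.fields["field1"].description == "text"          # structured field object
assert schema.fields["field1"].field_type == "<class 'str'>"  # not just a raw string value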


1 change: 1 addition & 0 deletions apps/query-eval/queryeval/types.py
@@ -58,6 +58,7 @@ class QueryEvalMetrics(BaseModel):

    # Document retrieval metrics
    doc_retrieval_recall: Optional[float] = None
    doc_retrieval_precision: Optional[float] = None

    # Performance metrics
    query_time: Optional[float] = None
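A minimal sketch, assuming the Pydantic conventions visible in this file, of how the new field sits in the model (not the full class):

from typing import Optional
from pydantic import BaseModel

class QueryEvalMetrics(BaseModel):
    # Both metrics stay None until do_eval fills them in
    doc_retrieval_recall: Optional[float] = None
    doc_retrieval_precision: Optional[float] = None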