diff --git a/apps/query-eval/queryeval/driver.py b/apps/query-eval/queryeval/driver.py
index 3e89a2d0f..6e7e0dfef 100644
--- a/apps/query-eval/queryeval/driver.py
+++ b/apps/query-eval/queryeval/driver.py
@@ -330,6 +330,22 @@ def do_eval(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalRe
             )
             metrics.plan_diff_count = len(plan_diff)
 
+        # Evaluate doc retrieval
+        if not query.expected_docs:
+            console.print("[yellow]:construction: No expected document list found, skipping.. ")
+        elif not result.retrieved_docs:
+            console.print("[yellow]:construction: No computed document list found, skipping.. ")
+        else:
+            if query.expected_docs.issubset(result.retrieved_docs):
+                console.print("[green]✔ Documents retrieved match")
+            else:
+                console.print("[red]:x: Document retrieval mismatch")
+                console.print(f"Missing docs: {query.expected_docs - result.retrieved_docs}")
+            metrics.doc_retrieval_recall = len(result.retrieved_docs & query.expected_docs) / len(query.expected_docs)
+            metrics.doc_retrieval_precision = len(result.retrieved_docs & query.expected_docs) / len(
+                result.retrieved_docs
+            )
+
         # Evaluate result
         if not result.result:
             console.print("[yellow] No query execution result available, skipping..", style="italic")
@@ -368,19 +384,41 @@ def query_all(self):
 
     def print_metrics_summary(self):
         """Summarize metrics."""
-        # Plan metrics
         console.rule("Evaluation summary")
+
+        # Plan metrics
         plan_correct = sum(1 for result in self.results_map.values() if result.metrics.plan_similarity == 1.0)
         console.print(f"Plans correct: {plan_correct}/{len(self.results_map)}")
-        average_plan_correctness = sum(result.metrics.plan_similarity for result in self.results_map.values()) / len(
-            self.results_map
-        )
+        average_plan_correctness = sum(
+            result.metrics.plan_similarity for result in self.results_map.values() if result.metrics.plan_similarity
+        ) / len(self.results_map)
         console.print(f"Avg. plan correctness: {average_plan_correctness}")
         console.print(
             "Avg. plan diff count: "
-            f"{sum(result.metrics.plan_diff_count for result in self.results_map.values()) / len(self.results_map)}"
+            + str(
+                sum(
+                    result.metrics.plan_diff_count
+                    for result in self.results_map.values()
+                    if result.metrics.plan_diff_count
+                )
+                / len(self.results_map)
+            )
         )
-
+        # Evaluate doc retrieval
+        correct_retrievals = sum(
+            1 for result in self.results_map.values() if result.metrics.doc_retrieval_recall == 1.0
+        )
+        expected_retrievals = sum(1 for result in self.results_map.values() if result.query.expected_docs)
+        average_precision = (
+            sum(
+                result.metrics.doc_retrieval_precision
+                for result in self.results_map.values()
+                if result.metrics.doc_retrieval_precision
+            )
+            / expected_retrievals
+        )
+        console.print(f"Successful doc retrievals: {correct_retrievals}/{expected_retrievals}")
+        console.print(f"Average precision: {average_precision}")
         # TODO: Query execution metrics
         console.print("Query result correctness: not implemented")
 
diff --git a/apps/query-eval/queryeval/test-results.yaml b/apps/query-eval/queryeval/test-results.yaml
new file mode 100644
index 000000000..03415fe5b
--- /dev/null
+++ b/apps/query-eval/queryeval/test-results.yaml
@@ -0,0 +1,74 @@
+config:
+  config_file: /private/var/folders/rd/5lfl86jj2yb3gn2zwwgqjgrc0000gn/T/pytest-of-vinayakthapliyal/pytest-108/test_driver_do_query0/test_input.yaml
+  doc_limit: null
+  dry_run: false
+  index: test-index
+  llm: null
+  llm_cache_path: null
+  log_file: null
+  natural_language_response: true
+  overwrite: false
+  query_cache_path: null
+  results_file: test-results.yaml
+  tags: null
+data_schema:
+  fields:
+    field1:
+      description: text
+      examples: null
+      field_type:
+    field2:
+      description: keyword
+      examples: null
+      field_type:
+examples: null
+results:
+- error: null
+  metrics:
+    correctness_score: null
+    doc_retrieval_precision: null
+    doc_retrieval_recall: null
+    plan_diff_count: null
+    plan_generation_time: null
+    plan_similarity: null
+    query_time: null
+    similarity_score: null
+  notes: null
+  plan: null
+  query:
+    expected: null
+    expected_docs:
+    - doc1.pdf
+    - doc2.pdf
+    expected_plan: null
+    notes: null
+    plan: null
+    query: test query 1
+    tags:
+    - test
+  result: null
+  retrieved_docs: null
+  timestamp: '2024-11-05T02:46:16.861230+00:00'
+- error: null
+  metrics:
+    correctness_score: null
+    doc_retrieval_precision: null
+    doc_retrieval_recall: null
+    plan_diff_count: null
+    plan_generation_time: null
+    plan_similarity: null
+    query_time: null
+    similarity_score: null
+  notes: null
+  plan: null
+  query:
+    expected: null
+    expected_docs: null
+    expected_plan: null
+    notes: null
+    plan: null
+    query: test query 2
+    tags: null
+  result: null
+  retrieved_docs: null
+  timestamp: '2024-11-05T02:46:16.861896+00:00'
diff --git a/apps/query-eval/queryeval/test_driver.py b/apps/query-eval/queryeval/test_driver.py
index cdb35a03e..92c293fa4 100644
--- a/apps/query-eval/queryeval/test_driver.py
+++ b/apps/query-eval/queryeval/test_driver.py
@@ -1,6 +1,8 @@
 import pytest
 from unittest.mock import MagicMock, patch
 
+from sycamore.query.schema import OpenSearchSchema, OpenSearchSchemaField
+
 from queryeval.driver import QueryEvalDriver
 from queryeval.types import (
     QueryEvalQuery,
@@ -14,7 +16,13 @@
 @pytest.fixture
 def mock_client():
     client = MagicMock()
-    client.get_opensearch_schema.return_value = {"field1": "text", "field2": "keyword"}
+    schema = OpenSearchSchema(
+        fields={
+            "field1": OpenSearchSchemaField(field_type="", description="text"),
+            "field2": OpenSearchSchemaField(field_type="", description="keyword"),
+        }
+    )
+    client.get_opensearch_schema.return_value = schema
     return client
 
 
@@ -29,11 +37,13 @@ def test_input_file(tmp_path):
     input_file.write_text(
         """
 config:
+  index: test-index
   results_file: test-results.yaml
 
 queries:
   - query: "test query 1"
     tags: ["test"]
+    expected_docs: ["doc1.pdf", "doc2.pdf"]
   - query: "test query 2"
 """
     )
@@ -46,8 +56,10 @@ def test_driver_init(test_input_file, mock_client):
 
     assert driver.config.config.index == "test-index"
     assert driver.config.config.doc_limit == 10
-    assert driver.config.config.tags == ["test"]
-    assert len(driver.config.queries) == 2
+    assert driver.config.queries[0].tags == ["test"]
+    assert driver.config.queries[1].tags is None
+    assert driver.config.queries[0].expected_docs == {"doc1.pdf", "doc2.pdf"}
+    assert driver.config.queries[1].expected_docs is None
 
 
 def test_driver_do_plan(test_input_file, mock_client, mock_plan):
@@ -75,15 +87,23 @@ def test_driver_do_query(test_input_file, mock_client, mock_plan):
 
         mock_client.run_plan.return_value = mock_query_result
         result = driver.do_query(query, result)
+        driver.eval_all()  # we are only asserting that this runs
         assert result.result == "test result"
 
 
 def test_driver_do_eval(test_input_file, mock_client, mock_plan):
     with patch("queryeval.driver.SycamoreQueryClient", return_value=mock_client):
         driver = QueryEvalDriver(input_file_path=test_input_file)
+        expected_docs = {"doc1.pdf", "doc2.pdf"}
+        query = QueryEvalQuery(query="test query", expected_plan=mock_plan, expected_docs=expected_docs)
+        result = QueryEvalResult(query=query, plan=mock_plan, retrieved_docs=expected_docs)
+
+        result = driver.do_eval(query, result)
+        assert result.metrics.plan_similarity == 1.0
+        assert result.metrics.doc_retrieval_recall == 1.0
 
-        query = QueryEvalQuery(query="test query", expected_plan=mock_plan)
-        result = QueryEvalResult(query=query, plan=mock_plan)
+        result = QueryEvalResult(query=query, plan=mock_plan, retrieved_docs={"doc1.pdf"})
 
         result = driver.do_eval(query, result)
         assert result.metrics.plan_similarity == 1.0
+        assert result.metrics.doc_retrieval_recall == 0.5
diff --git a/apps/query-eval/queryeval/types.py b/apps/query-eval/queryeval/types.py
index 8d558e65c..ff47f487d 100644
--- a/apps/query-eval/queryeval/types.py
+++ b/apps/query-eval/queryeval/types.py
@@ -33,6 +33,7 @@ class QueryEvalQuery(BaseModel):
     query: str
     expected: Optional[Union[str, List[Dict[str, Any]]]] = None
     expected_plan: Optional[LogicalPlan] = None
+    expected_docs: Optional[Set[str]] = None
     plan: Optional[LogicalPlan] = None
     tags: Optional[List[str]] = None
     notes: Optional[str] = None
@@ -50,10 +51,19 @@ class QueryEvalInputFile(BaseModel):
 class QueryEvalMetrics(BaseModel):
     """Represents metrics associated with a result."""
 
+    # Plan metrics
     plan_generation_time: Optional[float] = None
     plan_similarity: Optional[float] = None
     plan_diff_count: Optional[int] = None
+
+    # Document retrieval metrics
+    doc_retrieval_recall: Optional[float] = None
+    doc_retrieval_precision: Optional[float] = None
+
+    # Performance metrics
     query_time: Optional[float] = None
+
+    # String answer metrics
     correctness_score: Optional[float] = None
     similarity_score: Optional[float] = None