Skip to content

Commit

Permalink
Refactor: Fix tests that were not passing even though an OpenAI API key was passed
Browse files Browse the repository at this point in the history

- Implement abstractmethod in dae_evaluator and refactor to new OpenAI API
- Refactor dimension function in duckdb_adapter
- Test refactoring: ClinVar wrapper, DAE evaluator, splitter
- Fix issues with readonly database and linting
  • Loading branch information
iQuxLE committed Sep 11, 2024
1 parent fa2c2c7 commit ab5763d
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 62 deletions.
6 changes: 2 additions & 4 deletions src/curate_gpt/evaluation/dae_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import csv
import logging
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import List, TextIO

Expand Down Expand Up @@ -50,8 +49,7 @@ def evaluate(
"""
agent = self.agent
db = agent.knowledge_source
# TODO: use get()
test_objs = list(db.peek(collection=test_collection, limit=num_tests))
test_objs = list(db.find(collection=test_collection))
if any(obj for obj in test_objs if any(f not in obj for f in self.fields_to_predict)):
logger.info("Alternate strategy to get test objs; query whole collection")
test_objs = db.peek(collection=test_collection, limit=1000000)
Expand Down Expand Up @@ -135,4 +133,4 @@ def evaluate(
return aggregated

def evaluate_object(self, obj, **kwargs) -> ClassificationMetrics:
raise NotImplementedError
pass
2 changes: 1 addition & 1 deletion src/curate_gpt/evaluation/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def stratify_collection_to_store(
size = len(objs)
cn = f"{collection}_{sn}_{size}"
collections[sn] = cn
logging.info(f"Writing {size} objects to {cn}")
logger.info(f"Writing {size} objects to {cn}")
if cn in existing_collections:
logger.info(f"Collection {cn} already exists")
if not force:
Expand Down
1 change: 0 additions & 1 deletion src/curate_gpt/store/duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,6 @@ def _get_embedding_dimension(self, model_name: str) -> int:
if model_key == "" or model_key not in MODEL_MAP.keys():
model_key = DEFAULT_OPENAI_MODEL
model_info = MODEL_MAP.get(model_key, DEFAULT_OPENAI_MODEL)
print(f"Model info: {model_info}")
return model_info[1]
else:
return MODEL_MAP[DEFAULT_OPENAI_MODEL][1]
Expand Down
6 changes: 3 additions & 3 deletions src/curate_gpt/wrappers/clinical/clinvar_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def objects_from_dict(self, results: Dict) -> List[Dict]:
for r in results["eSummaryResult"]["DocumentSummarySet"]["DocumentSummary"]:
obj = {}
obj["id"] = "clinvar:" + r["accession"]
obj["clinical_significance"] = r["clinical_significance"]["description"]
obj["clinical_significance_status"] = r["clinical_significance"]["review_status"]
obj["clinical_significance"] = r["germline_classification"]["description"]
obj["clinical_significance_status"] = r["germline_classification"]["review_status"]
obj["gene_sort"] = r["gene_sort"]
if "genes" in r and r["genes"]:
if "gene" in r["genes"]:
Expand All @@ -46,7 +46,7 @@ def objects_from_dict(self, results: Dict) -> List[Dict]:
obj["protein_change"] = r["protein_change"]
obj["title"] = r["title"]
obj["traits"] = [
self._trait_from_dict(t) for t in r["trait_set"]["trait"] if isinstance(t, dict)
self._trait_from_dict(t) for t in r.get("trait_set", {}).get("trait", []) if isinstance(t, dict)
]
objs.append(obj)
return objs
Expand Down
1 change: 0 additions & 1 deletion tests/store/test_duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def test_the_embedding_function_variations(
expected_name = "test_collection"
else:
# Specific case: Collection specified, model may or may not be specified
print("\n\n",model,"\n\n")
db.insert(objs, collection=collection, model=model)
expected_model = model if model else "all-MiniLM-L6-v2"
expected_name = collection
Expand Down
15 changes: 9 additions & 6 deletions tests/wrappers/test_clinvar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import shutil
import os
import tempfile
import time

import pytest
Expand Down Expand Up @@ -30,11 +31,13 @@ def test_clinvar_transform():

@pytest.fixture
def wrapper() -> ClinVarWrapper:
shutil.rmtree(TEMP_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_DB))
extractor = BasicExtractor()
db.reset()
return ClinVarWrapper(local_store=db, extractor=extractor)
with tempfile.TemporaryDirectory() as temp_dir:
db_path = os.path.join(temp_dir, TEMP_DB)
# shutil.rmtree(TEMP_DB, ignore_errors=True)
db = ChromaDBAdapter(db_path)
extractor = BasicExtractor()
db.reset()
return ClinVarWrapper(local_store=db, extractor=extractor)


@requires_openai_api_key
Expand Down
33 changes: 18 additions & 15 deletions tests/wrappers/test_evidence_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import shutil
import os
import tempfile
from typing import Type

import pytest
Expand Down Expand Up @@ -31,17 +32,19 @@
],
)
def test_evidence_inference(source: Type[BaseWrapper]):
shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_PUBMED_DB))
extractor = BasicExtractor()
db.reset()
pubmed = source(local_store=db, extractor=extractor)
ea = EvidenceAgent(chat_agent=pubmed)
obj = {
"label": "acinar cells of the salivary gland",
"relationships": [
{"predicate": "HasFunction", "object": "ManufactureSaliva"},
],
}
resp = ea.find_evidence(obj)
print(yaml.dump(resp))
with tempfile.TemporaryDirectory() as temp_dir:
db_path = os.path.join(temp_dir, TEMP_PUBMED_DB)
# shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(db_path)
extractor = BasicExtractor()
db.reset()
pubmed = source(local_store=db, extractor=extractor)
ea = EvidenceAgent(chat_agent=pubmed)
obj = {
"label": "acinar cells of the salivary gland",
"relationships": [
{"predicate": "HasFunction", "object": "ManufactureSaliva"},
],
}
resp = ea.find_evidence(obj)
print(yaml.dump(resp))
24 changes: 14 additions & 10 deletions tests/wrappers/test_ncbi_biosample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
import shutil
import tempfile
import time

import yaml
Expand Down Expand Up @@ -34,13 +36,15 @@ def test_biosample_search():

@requires_openai_api_key
def test_biosample_chat():
shutil.rmtree(TEMP_BIOSAMPLE_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_BIOSAMPLE_DB))
extractor = BasicExtractor()
db.reset()
wrapper = NCBIBiosampleWrapper(local_store=db, extractor=extractor)
chat = ChatAgent(knowledge_source=wrapper, extractor=extractor)
response = chat.chat("what are some characteristics of the gut microbiome in Crohn's disease?")
print(response.formatted_body)
for ref in response.references:
print(ref)
with tempfile.TemporaryDirectory() as temp_dir:
db_path = os.path.join(temp_dir, TEMP_BIOSAMPLE_DB)
# shutil.rmtree(TEMP_BIOSAMPLE_DB, ignore_errors=True)
db = ChromaDBAdapter(db_path)
extractor = BasicExtractor()
db.reset()
wrapper = NCBIBiosampleWrapper(local_store=db, extractor=extractor)
chat = ChatAgent(knowledge_source=wrapper, extractor=extractor)
response = chat.chat("what are some characteristics of the gut microbiome in Crohn's disease?")
print(response.formatted_body)
for ref in response.references:
print(ref)
47 changes: 26 additions & 21 deletions tests/wrappers/test_pubmed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import shutil
import os
import tempfile
import time

from curate_gpt import ChromaDBAdapter
Expand Down Expand Up @@ -43,27 +44,31 @@ def test_full_text():

@requires_openai_api_key
def test_pubmed_search():
shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_PUBMED_DB))
extractor = BasicExtractor()
db.reset()
pubmed = PubmedWrapper(local_store=db, extractor=extractor)
results = list(pubmed.search("acinar cells of the salivary gland"))
assert len(results) > 0
top_result = results[0][0]
print(top_result)
time.sleep(0.5)
results2 = list(pubmed.search(top_result["title"]))
assert len(results2) > 0
with tempfile.TemporaryDirectory() as temp_dir:
db_path = os.path.join(temp_dir, TEMP_PUBMED_DB)
# shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(db_path)
extractor = BasicExtractor()
db.reset()
pubmed = PubmedWrapper(local_store=db, extractor=extractor)
results = list(pubmed.search("acinar cells of the salivary gland"))
assert len(results) > 0
top_result = results[0][0]
print(top_result)
time.sleep(0.5)
results2 = list(pubmed.search(top_result["title"]))
assert len(results2) > 0


@requires_openai_api_key
def test_pubmed_chat():
shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(str(TEMP_PUBMED_DB))
extractor = BasicExtractor()
db.reset()
pubmed = PubmedWrapper(local_store=db, extractor=extractor)
chat = ChatAgent(knowledge_source=pubmed, extractor=extractor)
response = chat.chat("what diseases are associated with acinar cells of the salivary gland")
print(response)
with tempfile.TemporaryDirectory() as temp_dir:
db_path = os.path.join(temp_dir, TEMP_PUBMED_DB)
# shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
db = ChromaDBAdapter(db_path)
extractor = BasicExtractor()
db.reset()
pubmed = PubmedWrapper(local_store=db, extractor=extractor)
chat = ChatAgent(knowledge_source=pubmed, extractor=extractor)
response = chat.chat("what diseases are associated with acinar cells of the salivary gland")
print(response)

0 comments on commit ab5763d

Please sign in to comment.