From 535e45c418a8008ef9f9a352641c9e75e8e7bda0 Mon Sep 17 00:00:00 2001 From: Emmett McFaralne Date: Sun, 1 Sep 2024 17:19:43 -0400 Subject: [PATCH] replaced DEFAULT_VLM with DEFAULT_AI_MODEL, added unit test for JSON parsing --- README.md | 2 +- tests/test_extractor.py | 33 ++++++++++++++++++++++++++++++++- tests/test_scraper.py | 30 +++++++++++++++++++++--------- thepipe/extract.py | 5 +++-- thepipe/scraper.py | 14 +++++++------- 5 files changed, 64 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 2e9fed9..775ec25 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ For a local installation, you can use the following command: pip install thepipe-api[local] ``` -You must have a local LLM server setup and running for AI extraction features. You can use any local LLM server that follows OpenAI format (such as [LiteLLM](https://github.com/BerriAI/litellm) or [OpenRouter](https://openrouter.ai/)). Next, set the `LLM_SERVER_BASE_URL` environment variable to your LLM server's endpoint URL and set `LLM_SERVER_API_KEY` to the API key for your LLM of choice. the `DEFAULT_VLM` environment variable can be set to the model name of your LLM. For example, you may use `openai/gpt-4o-mini` if using OpenRouter or `gpt-4o-mini` if using OpenAI. +You must have a local LLM server setup and running for AI extraction features. You can use any local LLM server that follows OpenAI format (such as [LiteLLM](https://github.com/BerriAI/litellm) or [OpenRouter](https://openrouter.ai/)). Next, set the `LLM_SERVER_BASE_URL` environment variable to your LLM server's endpoint URL and set `LLM_SERVER_API_KEY` to the API key for your LLM of choice. the `DEFAULT_AI_MODEL` environment variable can be set to the model name of your LLM. For example, you may use `openai/gpt-4o-mini` if using OpenRouter or `gpt-4o-mini` if using OpenAI. For full functionality with media-rich sources, you will need to install the following dependencies: diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 5d74ec3..8210d41 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -5,7 +5,7 @@ import os import json sys.path.append('..') -from thepipe.extract import extract +from thepipe.extract import extract, extract_json_from_response from thepipe.core import Chunk class TestExtractor(unittest.TestCase): @@ -27,6 +27,37 @@ def setUp(self): self.chunks = [Chunk(path="receipt.md", texts=[self.example_receipt])] + def test_extract_json_from_response(self): + # List of test cases with expected results + test_cases = [ + # Case 1: JSON enclosed in triple backticks + { + "input": "```json\n{\"key1\": \"value1\", \"key2\": 2}\n```", + "expected": {"key1": "value1", "key2": 2} + }, + # Case 2: JSON directly in the response + { + "input": "{\"key1\": \"value1\", \"key2\": 2}", + "expected": {"key1": "value1", "key2": 2} + }, + # Case 3: Response contains multiple JSON objects + { + "input": "Random text {\"key1\": \"value1\"} and another {\"key2\": 2}", + "expected": [{"key1": "value1"}, {"key2": 2}] + }, + # Case 4: Response contains incomplete JSON + { + "input": "Random text {\"key1\": \"value1\"} and another {\"key2\": 2", + "expected": {"key1": "value1"} + } + + ] + + for i, case in enumerate(test_cases): + with self.subTest(i=i): + result = extract_json_from_response(case["input"]) + self.assertEqual(result, case["expected"]) + def test_extract(self): results, total_tokens_used = extract( chunks=self.chunks, diff --git a/tests/test_scraper.py b/tests/test_scraper.py index e2a9615..cb3010f 100644 --- a/tests/test_scraper.py +++ b/tests/test_scraper.py @@ -40,15 +40,15 @@ def test_scrape_ipynb(self): self.assertTrue(any(len(chunk.images) > 0 for chunk in chunks)) # requires modal token to run - #def test_scrape_pdf_with_ai_extraction(self): - # chunks = scraper.scrape_file("tests/files/example.pdf", ai_extraction=True, verbose=True, local=True) - # # verify it scraped the pdf file into chunks - # self.assertEqual(type(chunks), list) - # self.assertNotEqual(len(chunks), 0) - # self.assertEqual(type(chunks[0]), core.Chunk) - # # verify it scraped the data - # for chunk in chunks: - # self.assertIsNotNone(chunk.texts or chunk.images) + def test_scrape_pdf_with_ai_extraction(self): + chunks = scraper.scrape_file("tests/files/example.pdf", ai_extraction=True, verbose=True, local=True) + # verify it scraped the pdf file into chunks + self.assertEqual(type(chunks), list) + self.assertNotEqual(len(chunks), 0) + self.assertEqual(type(chunks[0]), core.Chunk) + # verify it scraped the data + for chunk in chunks: + self.assertIsNotNone(chunk.texts or chunk.images) def test_scrape_docx(self): chunks = scraper.scrape_file(self.files_directory+"/example.docx", verbose=True, local=True) @@ -148,6 +148,18 @@ def test_scrape_url(self): chunks = scraper.scrape_url('https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf', local=True) self.assertEqual(len(chunks), 1) + def test_scrape_url_with_ai_extraction(self): + # verify web page scrape result with ai extraction + chunks = scraper.scrape_url('https://en.wikipedia.org/wiki/Piping', ai_extraction=True, local=True) + for chunk in chunks: + self.assertEqual(type(chunk), core.Chunk) + self.assertEqual(chunk.path, 'https://en.wikipedia.org/wiki/Piping') + # assert if any of the texts in chunk.texts contains 'pipe' + self.assertGreater(len(chunk.texts), 0) + self.assertIn('pipe', chunk.texts[0]) + # verify if at least one image was scraped + self.assertTrue(any(len(chunk.images) > 0 for chunk in chunks)) + @unittest.skipUnless(os.environ.get('GITHUB_TOKEN'), "requires GITHUB_TOKEN") def test_scrape_github(self): chunks = scraper.scrape_url('https://github.com/emcf/thepipe', local=True) diff --git a/thepipe/extract.py b/thepipe/extract.py index fc1afb6..0355b51 100644 --- a/thepipe/extract.py +++ b/thepipe/extract.py @@ -10,8 +10,9 @@ from openai import OpenAI DEFAULT_EXTRACTION_PROMPT = "Extract structured information from the above document according to the following schema: {schema}. Immediately return valid JSON formatted data. If there is missing data, you may use null, but use your reasoning to always fill in every column as best you can. Always immediately return valid JSON." +DEFAULT_AI_MODEL = os.getenv("DEFAULT_AI_MODEL", "gpt-4o-mini") -def extract_json_from_response(llm_response: str) -> Optional[Dict]: +def extract_json_from_response(llm_response: str) -> Union[Dict, List[Dict], None]: def clean_response_text(llm_response: str) -> str: return llm_response.encode('utf-8', 'ignore').decode('utf-8') @@ -99,7 +100,7 @@ def extract_from_chunk(chunk: Chunk, chunk_index: int, schema: str, ai_model: st response_dict = {"chunk_index": chunk_index, "source": source, "error": str(e)} return response_dict, tokens_used -def extract(chunks: List[Chunk], schema: Union[str, Dict], ai_model: str = 'google/gemma-2-9b-it', multiple_extractions: bool = False, extraction_prompt: str = DEFAULT_EXTRACTION_PROMPT, host_images: bool = False) -> Tuple[List[Dict], int]: +def extract(chunks: List[Chunk], schema: Union[str, Dict], ai_model: Optional[str] = 'google/gemma-2-9b-it', multiple_extractions: Optional[bool] = False, extraction_prompt: Optional[str] = DEFAULT_EXTRACTION_PROMPT, host_images: Optional[bool] = False) -> Tuple[List[Dict], int]: if isinstance(schema, dict): schema = json.dumps(schema) diff --git a/thepipe/scraper.py b/thepipe/scraper.py index b33bdd9..5df71bf 100644 --- a/thepipe/scraper.py +++ b/thepipe/scraper.py @@ -35,7 +35,7 @@ EXTRACTION_PROMPT = os.getenv("EXTRACTION_PROMPT", """An open source document is given. Output the entire extracted contents from the document in detailed markdown format. Be sure to correctly format markdown for headers, paragraphs, lists, tables, menus, equations, full text contents, etc. Always reply immediately with only markdown. Do not output anything else.""") -DEFAULT_VLM = os.getenv("DEFAULT_VLM", "gpt-4o-mini") +DEFAULT_AI_MODEL = os.getenv("DEFAULT_AI_MODEL", "gpt-4o-mini") FILESIZE_LIMIT_MB = os.getenv("FILESIZE_LIMIT_MB", 50) def detect_source_type(source: str) -> str: @@ -55,7 +55,7 @@ def detect_source_type(source: str) -> str: mimetype = result.output.mime_type return mimetype -def scrape_file(filepath: str, ai_extraction: bool = False, text_only: bool = False, verbose: bool = False, local: bool = False, chunking_method: Optional[Callable] = chunk_by_page) -> List[Chunk]: +def scrape_file(filepath: str, ai_extraction: bool = False, text_only: bool = False, verbose: bool = False, local: bool = False, chunking_method: Optional[Callable] = chunk_by_page, ai_model: Optional[str] = DEFAULT_AI_MODEL) -> List[Chunk]: if not local: with open(filepath, 'rb') as f: response = requests.post( @@ -92,7 +92,7 @@ def scrape_file(filepath: str, ai_extraction: bool = False, text_only: bool = Fa if verbose: print(f"[thepipe] Scraping {source_type}: {filepath}...") if source_type == 'application/pdf': - scraped_chunks = scrape_pdf(file_path=filepath, ai_extraction=ai_extraction, text_only=text_only, verbose=verbose) + scraped_chunks = scrape_pdf(file_path=filepath, ai_extraction=ai_extraction, text_only=text_only, verbose=verbose, ai_model=ai_model) elif source_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': scraped_chunks = scrape_docx(file_path=filepath, verbose=verbose, text_only=text_only) elif source_type == 'application/vnd.openxmlformats-officedocument.presentationml.presentation': @@ -149,7 +149,7 @@ def scrape_zip(file_path: str, include_regex: Optional[str] = None, verbose: boo chunks = scrape_directory(dir_path=temp_dir, include_regex=include_regex, verbose=verbose, ai_extraction=ai_extraction, text_only=text_only, local=local) return chunks -def scrape_pdf(file_path: str, ai_extraction: bool = False, text_only: bool = False, verbose: bool = False) -> List[Chunk]: +def scrape_pdf(file_path: str, ai_extraction: Optional[bool] = False, text_only: Optional[bool] = False, ai_model: Optional[str] = DEFAULT_AI_MODEL, verbose: Optional[bool] = False) -> List[Chunk]: chunks = [] MAX_PAGES = 128 @@ -188,7 +188,7 @@ def process_page(page_num): }, ] response = openrouter_client.chat.completions.create( - model=DEFAULT_VLM, + model=ai_model, messages=messages, temperature=0.2 ) @@ -289,7 +289,7 @@ def scrape_spreadsheet(file_path: str, source_type: str) -> List[Chunk]: chunks.append(Chunk(path=file_path, texts=[item_json])) return chunks -def ai_extract_webpage_content(url: str, text_only: bool = False, verbose: bool = False) -> Chunk: +def ai_extract_webpage_content(url: str, text_only: Optional[bool] = False, verbose: Optional[bool] = False, ai_model: Optional[str] = DEFAULT_AI_MODEL) -> Chunk: from playwright.sync_api import sync_playwright import modal from openai import OpenAI @@ -352,7 +352,7 @@ def ai_extract_webpage_content(url: str, text_only: bool = False, verbose: bool }, ] response = openrouter_client.chat.completions.create( - model=DEFAULT_VLM, + model=ai_model, messages=messages, temperature=0.2 )