Commit

replaced DEFAULT_VLM with DEFAULT_AI_MODEL, added unit test for JSON parsing
emcf committed Sep 1, 2024
1 parent a78e194 commit 535e45c
Showing 5 changed files with 64 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -68,7 +68,7 @@ For a local installation, you can use the following command:
pip install thepipe-api[local]
```

- You must have a local LLM server set up and running for AI extraction features. You can use any local LLM server that follows the OpenAI format (such as [LiteLLM](https://github.com/BerriAI/litellm) or [OpenRouter](https://openrouter.ai/)). Next, set the `LLM_SERVER_BASE_URL` environment variable to your LLM server's endpoint URL and set `LLM_SERVER_API_KEY` to the API key for your LLM of choice. The `DEFAULT_VLM` environment variable can be set to the model name of your LLM. For example, you may use `openai/gpt-4o-mini` if using OpenRouter or `gpt-4o-mini` if using OpenAI.
+ You must have a local LLM server set up and running for AI extraction features. You can use any local LLM server that follows the OpenAI format (such as [LiteLLM](https://github.com/BerriAI/litellm) or [OpenRouter](https://openrouter.ai/)). Next, set the `LLM_SERVER_BASE_URL` environment variable to your LLM server's endpoint URL and set `LLM_SERVER_API_KEY` to the API key for your LLM of choice. The `DEFAULT_AI_MODEL` environment variable can be set to the model name of your LLM. For example, you may use `openai/gpt-4o-mini` if using OpenRouter or `gpt-4o-mini` if using OpenAI.

For full functionality with media-rich sources, you will need to install the following dependencies:

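Note that `DEFAULT_AI_MODEL` is read from the environment at import time in both `thepipe/extract.py` and `thepipe/scraper.py` (see the diffs below), so one way to configure it is to set the variables before importing. A minimal sketch, where the URL, key, and model name are placeholders rather than values from this commit:

```python
import os

# Placeholder values: substitute your own server endpoint, key, and model.
os.environ["LLM_SERVER_BASE_URL"] = "http://localhost:4000/v1"
os.environ["LLM_SERVER_API_KEY"] = "sk-placeholder"
os.environ["DEFAULT_AI_MODEL"] = "gpt-4o-mini"

# Import after setting the environment, since DEFAULT_AI_MODEL is read at module load.
from thepipe.scraper import scrape_file
```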
33 changes: 32 additions & 1 deletion tests/test_extractor.py
@@ -5,7 +5,7 @@
import os
import json
sys.path.append('..')
-from thepipe.extract import extract
+from thepipe.extract import extract, extract_json_from_response
from thepipe.core import Chunk

class TestExtractor(unittest.TestCase):
@@ -27,6 +27,37 @@ def setUp(self):

        self.chunks = [Chunk(path="receipt.md", texts=[self.example_receipt])]

+    def test_extract_json_from_response(self):
+        # List of test cases with expected results
+        test_cases = [
+            # Case 1: JSON enclosed in triple backticks
+            {
+                "input": "```json\n{\"key1\": \"value1\", \"key2\": 2}\n```",
+                "expected": {"key1": "value1", "key2": 2}
+            },
+            # Case 2: JSON directly in the response
+            {
+                "input": "{\"key1\": \"value1\", \"key2\": 2}",
+                "expected": {"key1": "value1", "key2": 2}
+            },
+            # Case 3: Response contains multiple JSON objects
+            {
+                "input": "Random text {\"key1\": \"value1\"} and another {\"key2\": 2}",
+                "expected": [{"key1": "value1"}, {"key2": 2}]
+            },
+            # Case 4: Response contains incomplete JSON
+            {
+                "input": "Random text {\"key1\": \"value1\"} and another {\"key2\": 2",
+                "expected": {"key1": "value1"}
+            }
+
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                result = extract_json_from_response(case["input"])
+                self.assertEqual(result, case["expected"])
+
    def test_extract(self):
        results, total_tokens_used = extract(
            chunks=self.chunks,
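The body of `extract_json_from_response` is collapsed in this diff. As a reading aid, here is a minimal sketch of parsing logic that would satisfy the four test cases above; it illustrates the expected behavior and is not the actual implementation:

```python
import json
import re
from typing import Dict, List, Union

def extract_json_from_response(llm_response: str) -> Union[Dict, List[Dict], None]:
    # Strip a ```json ... ``` fence if present (case 1)
    fence = re.search(r"```(?:json)?\s*(.*?)\s*```", llm_response, re.DOTALL)
    if fence:
        llm_response = fence.group(1)
    # Try parsing the whole response as JSON first (cases 1 and 2)
    try:
        return json.loads(llm_response)
    except json.JSONDecodeError:
        pass
    # Otherwise, collect every balanced {...} span that parses as JSON;
    # unbalanced trailing fragments are simply skipped (cases 3 and 4)
    objects, depth, start = [], 0, 0
    for i, char in enumerate(llm_response):
        if char == "{":
            if depth == 0:
                start = i
            depth += 1
        elif char == "}" and depth > 0:
            depth -= 1
            if depth == 0:
                try:
                    objects.append(json.loads(llm_response[start:i + 1]))
                except json.JSONDecodeError:
                    pass
    if not objects:
        return None
    return objects[0] if len(objects) == 1 else objects
```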
30 changes: 21 additions & 9 deletions tests/test_scraper.py
@@ -40,15 +40,15 @@ def test_scrape_ipynb(self):
        self.assertTrue(any(len(chunk.images) > 0 for chunk in chunks))

    # requires modal token to run
-    #def test_scrape_pdf_with_ai_extraction(self):
-    #    chunks = scraper.scrape_file("tests/files/example.pdf", ai_extraction=True, verbose=True, local=True)
-    #    # verify it scraped the pdf file into chunks
-    #    self.assertEqual(type(chunks), list)
-    #    self.assertNotEqual(len(chunks), 0)
-    #    self.assertEqual(type(chunks[0]), core.Chunk)
-    #    # verify it scraped the data
-    #    for chunk in chunks:
-    #        self.assertIsNotNone(chunk.texts or chunk.images)
+    def test_scrape_pdf_with_ai_extraction(self):
+        chunks = scraper.scrape_file("tests/files/example.pdf", ai_extraction=True, verbose=True, local=True)
+        # verify it scraped the pdf file into chunks
+        self.assertEqual(type(chunks), list)
+        self.assertNotEqual(len(chunks), 0)
+        self.assertEqual(type(chunks[0]), core.Chunk)
+        # verify it scraped the data
+        for chunk in chunks:
+            self.assertIsNotNone(chunk.texts or chunk.images)

    def test_scrape_docx(self):
        chunks = scraper.scrape_file(self.files_directory+"/example.docx", verbose=True, local=True)
@@ -148,6 +148,18 @@ def test_scrape_url(self):
        chunks = scraper.scrape_url('https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf', local=True)
        self.assertEqual(len(chunks), 1)

+    def test_scrape_url_with_ai_extraction(self):
+        # verify web page scrape result with ai extraction
+        chunks = scraper.scrape_url('https://en.wikipedia.org/wiki/Piping', ai_extraction=True, local=True)
+        for chunk in chunks:
+            self.assertEqual(type(chunk), core.Chunk)
+            self.assertEqual(chunk.path, 'https://en.wikipedia.org/wiki/Piping')
+            # assert that at least one of the texts in chunk.texts contains 'pipe'
+            self.assertGreater(len(chunk.texts), 0)
+            self.assertIn('pipe', chunk.texts[0])
+        # verify that at least one image was scraped
+        self.assertTrue(any(len(chunk.images) > 0 for chunk in chunks))
+
    @unittest.skipUnless(os.environ.get('GITHUB_TOKEN'), "requires GITHUB_TOKEN")
    def test_scrape_github(self):
        chunks = scraper.scrape_url('https://github.com/emcf/thepipe', local=True)
5 changes: 3 additions & 2 deletions thepipe/extract.py
@@ -10,8 +10,9 @@
from openai import OpenAI

DEFAULT_EXTRACTION_PROMPT = "Extract structured information from the above document according to the following schema: {schema}. Immediately return valid JSON formatted data. If there is missing data, you may use null, but use your reasoning to always fill in every column as best you can. Always immediately return valid JSON."
+DEFAULT_AI_MODEL = os.getenv("DEFAULT_AI_MODEL", "gpt-4o-mini")

-def extract_json_from_response(llm_response: str) -> Optional[Dict]:
+def extract_json_from_response(llm_response: str) -> Union[Dict, List[Dict], None]:
    def clean_response_text(llm_response: str) -> str:
        return llm_response.encode('utf-8', 'ignore').decode('utf-8')

@@ -99,7 +100,7 @@ def extract_from_chunk(chunk: Chunk, chunk_index: int, schema: str, ai_model: st
        response_dict = {"chunk_index": chunk_index, "source": source, "error": str(e)}
        return response_dict, tokens_used

-def extract(chunks: List[Chunk], schema: Union[str, Dict], ai_model: str = 'google/gemma-2-9b-it', multiple_extractions: bool = False, extraction_prompt: str = DEFAULT_EXTRACTION_PROMPT, host_images: bool = False) -> Tuple[List[Dict], int]:
+def extract(chunks: List[Chunk], schema: Union[str, Dict], ai_model: Optional[str] = 'google/gemma-2-9b-it', multiple_extractions: Optional[bool] = False, extraction_prompt: Optional[str] = DEFAULT_EXTRACTION_PROMPT, host_images: Optional[bool] = False) -> Tuple[List[Dict], int]:
    if isinstance(schema, dict):
        schema = json.dumps(schema)

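Put together, the updated `extract` signature can be exercised as below. This is a hedged sketch that assumes a reachable LLM server configured as in the README; the receipt chunk and schema are illustrative, not taken from this commit:

```python
from thepipe.core import Chunk
from thepipe.extract import extract

# Illustrative input (mirrors the receipt fixture in tests/test_extractor.py).
chunks = [Chunk(path="receipt.md", texts=["Subtotal: $3.50\nTotal: $3.88"])]

results, total_tokens_used = extract(
    chunks=chunks,
    schema={"total": "float"},  # dict schemas are serialized to JSON internally
    ai_model="gpt-4o-mini",     # overrides the 'google/gemma-2-9b-it' default
)
print(results, total_tokens_used)
```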
14 changes: 7 additions & 7 deletions thepipe/scraper.py
@@ -35,7 +35,7 @@
EXTRACTION_PROMPT = os.getenv("EXTRACTION_PROMPT", """An open source document is given. Output the entire extracted contents from the document in detailed markdown format.
Be sure to correctly format markdown for headers, paragraphs, lists, tables, menus, equations, full text contents, etc.
Always reply immediately with only markdown. Do not output anything else.""")
-DEFAULT_VLM = os.getenv("DEFAULT_VLM", "gpt-4o-mini")
+DEFAULT_AI_MODEL = os.getenv("DEFAULT_AI_MODEL", "gpt-4o-mini")
FILESIZE_LIMIT_MB = os.getenv("FILESIZE_LIMIT_MB", 50)

def detect_source_type(source: str) -> str:
Expand All @@ -55,7 +55,7 @@ def detect_source_type(source: str) -> str:
    mimetype = result.output.mime_type
    return mimetype

-def scrape_file(filepath: str, ai_extraction: bool = False, text_only: bool = False, verbose: bool = False, local: bool = False, chunking_method: Optional[Callable] = chunk_by_page) -> List[Chunk]:
+def scrape_file(filepath: str, ai_extraction: bool = False, text_only: bool = False, verbose: bool = False, local: bool = False, chunking_method: Optional[Callable] = chunk_by_page, ai_model: Optional[str] = DEFAULT_AI_MODEL) -> List[Chunk]:
    if not local:
        with open(filepath, 'rb') as f:
            response = requests.post(
@@ -92,7 +92,7 @@ def scrape_file(filepath: str, ai_extraction: bool = False, text_only: bool = Fa
    if verbose:
        print(f"[thepipe] Scraping {source_type}: {filepath}...")
    if source_type == 'application/pdf':
-        scraped_chunks = scrape_pdf(file_path=filepath, ai_extraction=ai_extraction, text_only=text_only, verbose=verbose)
+        scraped_chunks = scrape_pdf(file_path=filepath, ai_extraction=ai_extraction, text_only=text_only, verbose=verbose, ai_model=ai_model)
    elif source_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
        scraped_chunks = scrape_docx(file_path=filepath, verbose=verbose, text_only=text_only)
    elif source_type == 'application/vnd.openxmlformats-officedocument.presentationml.presentation':
@@ -149,7 +149,7 @@ def scrape_zip(file_path: str, include_regex: Optional[str] = None, verbose: boo
        chunks = scrape_directory(dir_path=temp_dir, include_regex=include_regex, verbose=verbose, ai_extraction=ai_extraction, text_only=text_only, local=local)
    return chunks

-def scrape_pdf(file_path: str, ai_extraction: bool = False, text_only: bool = False, verbose: bool = False) -> List[Chunk]:
+def scrape_pdf(file_path: str, ai_extraction: Optional[bool] = False, text_only: Optional[bool] = False, ai_model: Optional[str] = DEFAULT_AI_MODEL, verbose: Optional[bool] = False) -> List[Chunk]:
    chunks = []
    MAX_PAGES = 128

@@ -188,7 +188,7 @@ def process_page(page_num):
            },
        ]
        response = openrouter_client.chat.completions.create(
-            model=DEFAULT_VLM,
+            model=ai_model,
            messages=messages,
            temperature=0.2
        )
@@ -289,7 +289,7 @@ def scrape_spreadsheet(file_path: str, source_type: str) -> List[Chunk]:
        chunks.append(Chunk(path=file_path, texts=[item_json]))
    return chunks

-def ai_extract_webpage_content(url: str, text_only: bool = False, verbose: bool = False) -> Chunk:
+def ai_extract_webpage_content(url: str, text_only: Optional[bool] = False, verbose: Optional[bool] = False, ai_model: Optional[str] = DEFAULT_AI_MODEL) -> Chunk:
    from playwright.sync_api import sync_playwright
    import modal
    from openai import OpenAI
@@ -352,7 +352,7 @@ def ai_extract_webpage_content(url: str, text_only: bool = False, verbose: bool
            },
        ]
        response = openrouter_client.chat.completions.create(
-            model=DEFAULT_VLM,
+            model=ai_model,
            messages=messages,
            temperature=0.2
        )
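With `ai_model` threaded through `scrape_file` and `scrape_pdf`, callers can now choose the vision model per call instead of relying on the module-level default. A sketch under the same assumptions as above (a configured LLM server; the file path mirrors the test fixture):

```python
from thepipe.scraper import scrape_file

# ai_model falls back to DEFAULT_AI_MODEL ("gpt-4o-mini" unless overridden
# by the environment) when omitted.
chunks = scrape_file(
    "tests/files/example.pdf",
    ai_extraction=True,              # route pages through the vision model
    ai_model="openai/gpt-4o-mini",   # e.g. an OpenRouter-style model name
    local=True,
)
for chunk in chunks:
    print(chunk.path, len(chunk.texts), len(chunk.images))
```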
