diff --git a/thepipe/extract.py b/thepipe/extract.py index 35ec192..7795a86 100644 --- a/thepipe/extract.py +++ b/thepipe/extract.py @@ -101,7 +101,10 @@ def extract_from_chunk(chunk: Chunk, chunk_index: int, schema: str, ai_model: st return response_dict, tokens_used -def extract(chunks: List[Chunk], schema: str, ai_model: str = 'google/gemma-2-9b-it', multiple_extractions: bool = False, extraction_prompt: str = DEFAULT_EXTRACTION_PROMPT, host_images: bool = False) -> Tuple[List[Dict], int]: +def extract(chunks: List[Chunk], schema: Union[str, Dict], ai_model: str = 'google/gemma-2-9b-it', multiple_extractions: bool = False, extraction_prompt: str = DEFAULT_EXTRACTION_PROMPT, host_images: bool = False) -> Tuple[List[Dict], int]: + if isinstance(schema, dict): + schema = json.dumps(schema) + results = [] total_tokens_used = 0 @@ -148,8 +151,6 @@ def extract_from_url( chunking_method: Callable[[List[Chunk]], List[Chunk]] = chunk_by_document, local: bool = False ) -> List[Dict]: #Tuple[List[Dict], int]: - if isinstance(schema, dict): - schema = json.dumps(schema) if local: chunks = scrape_url(url, text_only=text_only, ai_extraction=ai_extraction, verbose=verbose, local=local) chunked_content = chunking_method(chunks) @@ -212,8 +213,6 @@ def extract_from_file( chunking_method: Callable[[List[Chunk]], List[Chunk]] = chunk_by_document, local: bool = False ) -> List[Dict]:#Tuple[List[Dict], int]: - if isinstance(schema, dict): - schema = json.dumps(schema) if local: chunks = scrape_file(file_path, ai_extraction=ai_extraction, text_only=text_only, verbose=verbose) chunked_content = chunking_method(chunks)