From cf9f65bbcf2ba89be3d7c3a5927d7f21c6422046 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E5=B9=B3?=
Date: Sat, 21 Sep 2024 16:26:18 +0800
Subject: [PATCH] Adapt to graphrag 0.3.1; optimize parts of the code logic
 and remove some redundant code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
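
Adapt the wrapper to the graphrag 0.3.1 API:

* index: replace the removed run_pipeline_with_config / PipelineConfig
  flow with the new build_index API, load settings via
  load_config_from_file / create_graphrag_config, switch logging to
  enable_logging_with_config, and add a skip_validations flag.
* search: add an optional streaming flag to run_local_search and
  run_global_search; streaming=False keeps the previous blocking
  behaviour and returns the response text, while streaming=True prints
  chunks as they arrive and returns (full_response, context_data).
* search: infer the latest output folder from its YYYYMMDD-HHMMSS run
  name instead of the directory mtime.

A minimal usage sketch (assumes a workspace under "rag" that has
already been indexed, as in the README):

    from graphrag_api.search import SearchRunner

    search_runner = SearchRunner(root_dir="rag")
    # Blocking call; returns the response text.
    print(search_runner.run_local_search(
        query="What are the top themes in this story?", streaming=False
    ))
    # Streams chunks to stdout; returns (full_response, context_data).
    response, context = search_runner.run_global_search(
        query="What are the top themes in this story?", streaming=True
    )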
exiting...") + reporter.dispose() + for task in asyncio.all_tasks(): + task.cancel() + reporter.info("All tasks cancelled. Exiting...") + + # Register signal handlers for SIGINT and SIGHUP + signal.signal(signal.SIGINT, handle_signal) + + if sys.platform != "win32": + signal.signal(signal.SIGHUP, handle_signal) + + @staticmethod + def logger(reporter: ProgressReporter): + def info(msg: str, verbose: bool = False): + log.info(msg) + if verbose: + reporter.info(msg) + + def error(msg: str, verbose: bool = False): + log.error(msg) + if verbose: + reporter.error(msg) + + def success(msg: str, verbose: bool = False): + log.info(msg) + if verbose: + reporter.success(msg) + + return info, error, success + @staticmethod def redact(input: dict) -> str: """Sanitize the config json.""" @@ -87,7 +124,7 @@ def redact_dict(input: dict) -> dict: "organization", }: if value is not None: - result[key] = f"REDACTED, length {len(value)}" + result[key] = "==== REDACTED ====" elif isinstance(value, dict): result[key] = redact_dict(value) elif isinstance(value, list): @@ -101,98 +138,79 @@ def redact_dict(input: dict) -> dict: def run(self): """Run the pipeline with the given config.""" + progress_reporter = load_progress_reporter(self.reporter or "rich") + info, error, success = self.logger(progress_reporter) run_id = self.resume or time.strftime("%Y%m%d-%H%M%S") - self._enable_logging(self.root, run_id, self.verbose) - progress_reporter = self._get_progress_reporter(self.reporter) + if self.init: self._initialize_project_at(self.root, progress_reporter) sys.exit(0) - if self.overlay_defaults: - pipeline_config: Union[str, PipelineConfig] = self._create_default_config( - self.root, - self.config, - self.verbose, - self.dryrun or False, - progress_reporter, + + if self.overlay_defaults or self.config: + config_path = ( + Path(self.root) / self.config if self.config else resolve_config_path_with_root(self.root) ) + default_config = load_config_from_file(config_path) + else: + try: + config_path = resolve_config_path_with_root(self.root) + default_config = load_config_from_file(config_path) + except FileNotFoundError: + default_config = create_graphrag_config(root_dir=self.root) + + if self.nocache: + default_config.cache.type = CacheType.none + + enabled_logging, log_path = enable_logging_with_config( + default_config, run_id, self.verbose + ) + if enabled_logging: + info(f"Logging enabled at {log_path}", True) else: - pipeline_config: Union[ - str, PipelineConfig - ] = self.config or self._create_default_config( - self.root, None, self.verbose, self.dryrun or False, progress_reporter + info( + f"Logging not enabled for config {self.redact(default_config.model_dump())}", + True, ) - cache = NoopPipelineCache() if self.nocache else None + + if self.skip_validations: + validate_config_names(progress_reporter, default_config) + info(f"Starting pipeline run for: {run_id}, {self.dryrun=}", self.verbose) + info( + f"Using default configuration: {self.redact(default_config.model_dump())}", + self.verbose, + ) + + if self.dryrun: + info("Dry run complete, exiting...", True) + sys.exit(0) + pipeline_emit = self.emit.split(",") if self.emit else None - encountered_errors = False - - def _run_workflow_async() -> None: - import signal - - def handle_signal(signum, _): - # Handle the signal here - progress_reporter.info(f"Received signal {signum}, exiting...") - progress_reporter.dispose() - for task in asyncio.all_tasks(): - task.cancel() - progress_reporter.info("All tasks cancelled. 
Exiting...") - - # Register signal handlers for SIGINT and SIGHUP - signal.signal(signal.SIGINT, handle_signal) - - if sys.platform != "win32": - signal.signal(signal.SIGHUP, handle_signal) - - async def execute(): - nonlocal encountered_errors - async for output in run_pipeline_with_config( - pipeline_config, - run_id=run_id, - memory_profile=self.memprofile, - cache=cache, - progress_reporter=progress_reporter, - emit=( - [TableEmitterType(e) for e in pipeline_emit] - if pipeline_emit - else None - ), - is_resume_run=bool(self.resume), - ): - if output.errors and len(output.errors) > 0: - encountered_errors = True - progress_reporter.error(output.workflow) - else: - progress_reporter.success(output.workflow) - - progress_reporter.info(str(output.result)) - - if platform.system() == "Windows": - import nest_asyncio # type: ignore Ignoring because out of windows this will cause an error - - nest_asyncio.apply() - loop = asyncio.get_event_loop() - loop.run_until_complete(execute()) - elif sys.version_info >= (3, 11): - import uvloop # type: ignore Ignoring because on windows this will cause an error - - with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: # type: ignore Ignoring because minor versions this will throw an error - runner.run(execute()) - else: - import uvloop # type: ignore Ignoring because on windows this will cause an error - - uvloop.install() - asyncio.run(execute()) - - _run_workflow_async() + + self.register_signal_handlers(progress_reporter) + + outputs = asyncio.run( + build_index( + default_config, + run_id, + self.memprofile, + progress_reporter, + pipeline_emit + ) + ) + + encountered_errors = any( + output.errors and len(output.errors) > 0 for output in outputs + ) + progress_reporter.stop() if encountered_errors: - progress_reporter.error( - "Errors occurred during the pipeline run, see logs for more details." 
+            error(
+                "Errors occurred during the pipeline run, see logs for more details.", True
             )
         else:
-            progress_reporter.success("All workflows completed successfully.")
+            success("All workflows completed successfully.", True)
 
-        if self.cli:
-            sys.exit(1 if encountered_errors else 0)
+        sys.exit(1 if encountered_errors else 0)
 
     @staticmethod
     def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
@@ -245,42 +263,6 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
             COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
         )
 
-    def _create_default_config(
-        self,
-        root: str,
-        config: Optional[str],
-        verbose: bool,
-        dryrun: bool,
-        reporter: ProgressReporter,
-    ) -> PipelineConfig:
-        """Overlay default values on an existing config or create a default config if none is provided."""
-        if config and not Path(config).exists():
-            msg = f"Configuration file {config} does not exist"
-            raise ValueError(msg)
-
-        if not Path(root).exists():
-            msg = f"Root directory {root} does not exist"
-            raise ValueError(msg)
-
-        parameters = self._read_config_parameters(root, config, reporter)
-        log.info(
-            "using default configuration: %s",
-            self.redact(parameters.model_dump()),
-        )
-
-        if verbose or dryrun:
-            reporter.info(
-                f"Using default configuration: {self.redact(parameters.model_dump())}"
-            )
-        result = create_pipeline_config(parameters, verbose)
-        if verbose or dryrun:
-            reporter.info(f"Final Config: {self.redact(result.model_dump())}")
-
-        if dryrun:
-            reporter.info("dry run complete, exiting...")
-            sys.exit(0)
-        return result
-
     @staticmethod
     def _enable_logging(root: str, run_id: str, verbose: bool) -> None:
         """Enable logging to file and console."""
@@ -297,21 +279,3 @@ def _enable_logging(root: str, run_id: str, verbose: bool) -> None:
             level=logging.DEBUG if verbose else logging.INFO,
         )
 
-    @staticmethod
-    def _get_progress_reporter(
-        progress_reporter: Optional[str] = None,
-    ) -> ProgressReporter:
-        """Enable progress reporting to console."""
-        _reporter = progress_reporter or "print"
-        if _reporter == "null":
-            return NullProgressReporter()
-        elif _reporter == "print":
-            return PrintProgressReporter("GraphRAG Indexer ")
-        elif _reporter == "rich":
-            return RichProgressReporter("GraphRAG Indexer ")
-        else:
-            msg = (
-                f"Unsupported progress reporter: {_reporter}. "
" - f"Supported reporters are null, print and rich" - ) - raise ValueError(msg) diff --git a/graphrag_api/search.py b/graphrag_api/search.py index 4ce291a..58cde97 100644 --- a/graphrag_api/search.py +++ b/graphrag_api/search.py @@ -4,11 +4,13 @@ 基于 graphrag\\query\\cli.py 修改 """ - -import os +import asyncio import re +import sys from pathlib import Path from typing import cast +from pydantic import validate_call +from collections.abc import AsyncGenerator import pandas as pd @@ -29,6 +31,7 @@ read_indexer_reports, read_indexer_text_units, ) +from graphrag.query.api import _reformat_context_data from graphrag_api.common import BaseGraph @@ -53,6 +56,45 @@ def __init__( self.__local_agent = self.__get__local_agent() self.__global_agent = self.__get__global_agent() + @staticmethod + @validate_call(config={"arbitrary_types_allowed": True}) + async def search(search_agent, query): + """非流式搜索""" + result = await search_agent.asearch(query=query) + return result.response + + @staticmethod + @validate_call(config={"arbitrary_types_allowed": True}) + async def search_streaming(search_agent, query) -> AsyncGenerator: + """流式搜索""" + search_result = search_agent.astream_search(query=query) + context_data = None + get_context_data = True + async for stream_chunk in search_result: + if get_context_data: + context_data = _reformat_context_data(stream_chunk) + yield context_data + get_context_data = False + else: + yield stream_chunk + + async def run_streaming_search(self, search_agent, query): + full_response = "" + context_data = None + get_context_data = True + async for stream_chunk in self.search_streaming( + search_agent=search_agent, query=query + ): + if get_context_data: + context_data = stream_chunk + get_context_data = False + else: + full_response += stream_chunk + print(stream_chunk, end="") # noqa: T201 + sys.stdout.flush() # flush output buffer to display text immediately + print() # noqa: T201 + return full_response, context_data + @staticmethod def __get_embedding_description_store( entities: list[Entity], @@ -130,13 +172,15 @@ def __get__global_agent(self): response_type=self.response_type, ) - def run_global_search(self, query): + def run_global_search(self, query, streaming=False): """Run a global search with the given query.""" - result = self.__global_agent.search(query=query) + if streaming: + return asyncio.run( + self.run_streaming_search(search_agent=self.__global_agent, query=query) + ) - reporter.success(f"Global Search Response: {result.response}") - return result.response + return asyncio.run(self.search(search_agent=self.__global_agent, query=query)) def __get__local_agent(self): """获取local搜索引擎""" @@ -197,12 +241,14 @@ def __get__local_agent(self): response_type=self.response_type, ) - def run_local_search(self, query): + def run_local_search(self, query, streaming=False): """Run a local search with the given query.""" - result = self.__local_agent.search(query=query) - reporter.success(f"Local Search Response: {result.response}") - return result.response + if streaming: + return asyncio.run( + self.run_streaming_search(search_agent=self.__local_agent, query=query) + ) + return asyncio.run(self.search(search_agent=self.__local_agent, query=query)) def _configure_paths_and_settings( self, @@ -223,7 +269,11 @@ def _infer_data_dir(root: str) -> str: output = Path(root) / "output" # use the latest data-run folder if output.exists(): - folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True) + expr = re.compile(r"\d{8}-\d{6}") + filtered = [ + f for f in 
+            folders = sorted(filtered, key=lambda f: f.name, reverse=True)
             if len(folders) > 0:
                 folder = folders[0]
                 return str((folder / "artifacts").absolute())
diff --git a/requirements.txt b/requirements.txt
index 0eb07dd..b304f02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-graphrag==0.3.0
+graphrag==0.3.1
diff --git a/setup.py b/setup.py
index 0588cb4..0811634 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 
 setup(
     name="graphrag_api",
-    version="0.3.0",
+    version="0.3.1",
     packages=find_packages(exclude=["tests"]),
     install_requires=requirements,
     author="nightzjp",
diff --git a/tests/search_test.py b/tests/search_test.py
index 726f4f3..bae09ad 100644
--- a/tests/search_test.py
+++ b/tests/search_test.py
@@ -88,8 +88,8 @@ def __str__(self):
 
     match args.method:
         case SearchType.LOCAL:
-            search_runner.run_local_search(query=args.query[0])
+            search_runner.run_local_search(query=args.query[0], streaming=False)
         case SearchType.GLOBAL:
-            search_runner.run_global_search(query=args.query[0])
+            search_runner.run_global_search(query=args.query[0], streaming=False)
         case _:
             raise ValueError(INVALID_METHOD_ERROR)