
Commit

Adapt to graphrag version 0.3.1
Optimized parts of the code logic and removed some redundant code.
张建平 committed Sep 21, 2024
1 parent e37e977 commit cf9f65b
Showing 6 changed files with 178 additions and 164 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -146,8 +146,8 @@ from graphrag_api.search import SearchRunner

search_runner = SearchRunner(root_dir="rag")

search_runner.run_local_search(query="What are the top themes in this story?")
search_runner.run_global_search(query="What are the top themes in this story?")
search_runner.run_local_search(query="What are the top themes in this story?", streaming=False)
search_runner.run_global_search(query="What are the top themes in this story?", streaming=False)

# The output may contain special characters; use the following function to strip them, or handle them yourself.
search_runner.remove_sources(search_runner.run_local_search(query="What are the top themes in this story?"))
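
For reference, a minimal end-to-end sketch built only from the calls shown above (the query string and the "rag" root directory are placeholders, and it assumes remove_sources returns the cleaned string):

from graphrag_api.search import SearchRunner

search_runner = SearchRunner(root_dir="rag")

# Non-streaming local search; the raw answer may contain special characters
# and source markers, so strip them with remove_sources as described above.
raw_answer = search_runner.run_local_search(
    query="What are the top themes in this story?",
    streaming=False,
)
clean_answer = search_runner.remove_sources(raw_answer)
print(clean_answer)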
258 changes: 111 additions & 147 deletions graphrag_api/index.py
@@ -8,24 +8,22 @@
import asyncio
import json
import logging
import platform
import sys
import time
import warnings
from pathlib import Path
from typing import Optional, Union

from graphrag.index import PipelineConfig, create_pipeline_config
from graphrag.index.cache import NoopPipelineCache
from graphrag.index.progress import (
NullProgressReporter,
PrintProgressReporter,
ProgressReporter,
)
from graphrag.index.progress.rich import RichProgressReporter
from graphrag.index.run import run_pipeline_with_config
from typing import Optional

from graphrag.config import create_graphrag_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.config.enums import CacheType
from graphrag.config.config_file_loader import load_config_from_file, resolve_config_path_with_root
from graphrag.index.validate_config import validate_config_names
from graphrag.index.api import build_index
from graphrag.index.progress import ProgressReporter

from graphrag.index.progress.load_progress_reporter import load_progress_reporter

from graphrag.index.emit import TableEmitterType
from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
from graphrag.index.graph.extractors.community_reports.prompts import (
COMMUNITY_REPORT_PROMPT,
@@ -56,6 +54,7 @@ def __init__(
dryrun: bool = False,
init: bool = False,
overlay_defaults: bool = False,
skip_validations: bool = False
):
self.root = root
self.verbose = verbose
@@ -68,8 +67,46 @@ def __init__(
self.dryrun = dryrun
self.init = init
self.overlay_defaults = overlay_defaults
self.skip_validations = skip_validations
self.cli = False

@staticmethod
def register_signal_handlers(reporter: ProgressReporter):
import signal

def handle_signal(signum, _):
# Handle the signal here
reporter.info(f"Received signal {signum}, exiting...")
reporter.dispose()
for task in asyncio.all_tasks():
task.cancel()
reporter.info("All tasks cancelled. Exiting...")

# Register signal handlers for SIGINT and SIGHUP
signal.signal(signal.SIGINT, handle_signal)

if sys.platform != "win32":
signal.signal(signal.SIGHUP, handle_signal)

@staticmethod
def logger(reporter: ProgressReporter):
def info(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
reporter.info(msg)

def error(msg: str, verbose: bool = False):
log.error(msg)
if verbose:
reporter.error(msg)

def success(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
reporter.success(msg)

return info, error, success

@staticmethod
def redact(input: dict) -> str:
"""Sanitize the config json."""
@@ -87,7 +124,7 @@ def redact_dict(input: dict) -> dict:
"organization",
}:
if value is not None:
result[key] = f"REDACTED, length {len(value)}"
result[key] = "==== REDACTED ===="
elif isinstance(value, dict):
result[key] = redact_dict(value)
elif isinstance(value, list):
@@ -101,98 +138,79 @@ def redact_dict(input: dict) -> dict:

def run(self):
"""Run the pipeline with the given config."""
progress_reporter = load_progress_reporter(self.reporter or "rich")
info, error, success = self.logger(progress_reporter)
run_id = self.resume or time.strftime("%Y%m%d-%H%M%S")
self._enable_logging(self.root, run_id, self.verbose)
progress_reporter = self._get_progress_reporter(self.reporter)

if self.init:
self._initialize_project_at(self.root, progress_reporter)
sys.exit(0)
if self.overlay_defaults:
pipeline_config: Union[str, PipelineConfig] = self._create_default_config(
self.root,
self.config,
self.verbose,
self.dryrun or False,
progress_reporter,

if self.overlay_defaults or self.config:
config_path = (
Path(self.root) / self.config if self.config else resolve_config_path_with_root(self.root)
)
default_config = load_config_from_file(config_path)
else:
try:
config_path = resolve_config_path_with_root(self.root)
default_config = load_config_from_file(config_path)
except FileNotFoundError:
default_config = create_graphrag_config(root_dir=self.root)

if self.nocache:
default_config.cache.type = CacheType.none

enabled_logging, log_path = enable_logging_with_config(
default_config, run_id, self.verbose
)
if enabled_logging:
info(f"Logging enabled at {log_path}", True)
else:
pipeline_config: Union[
str, PipelineConfig
] = self.config or self._create_default_config(
self.root, None, self.verbose, self.dryrun or False, progress_reporter
info(
f"Logging not enabled for config {self.redact(default_config.model_dump())}",
True,
)
cache = NoopPipelineCache() if self.nocache else None

if not self.skip_validations:
validate_config_names(progress_reporter, default_config)
info(f"Starting pipeline run for: {run_id}, {self.dryrun=}", self.verbose)
info(
f"Using default configuration: {self.redact(default_config.model_dump())}",
self.verbose,
)

if self.dryrun:
info("Dry run complete, exiting...", True)
sys.exit(0)

pipeline_emit = self.emit.split(",") if self.emit else None
encountered_errors = False

def _run_workflow_async() -> None:
import signal

def handle_signal(signum, _):
# Handle the signal here
progress_reporter.info(f"Received signal {signum}, exiting...")
progress_reporter.dispose()
for task in asyncio.all_tasks():
task.cancel()
progress_reporter.info("All tasks cancelled. Exiting...")

# Register signal handlers for SIGINT and SIGHUP
signal.signal(signal.SIGINT, handle_signal)

if sys.platform != "win32":
signal.signal(signal.SIGHUP, handle_signal)

async def execute():
nonlocal encountered_errors
async for output in run_pipeline_with_config(
pipeline_config,
run_id=run_id,
memory_profile=self.memprofile,
cache=cache,
progress_reporter=progress_reporter,
emit=(
[TableEmitterType(e) for e in pipeline_emit]
if pipeline_emit
else None
),
is_resume_run=bool(self.resume),
):
if output.errors and len(output.errors) > 0:
encountered_errors = True
progress_reporter.error(output.workflow)
else:
progress_reporter.success(output.workflow)

progress_reporter.info(str(output.result))

if platform.system() == "Windows":
import nest_asyncio # type: ignore Ignoring because out of windows this will cause an error

nest_asyncio.apply()
loop = asyncio.get_event_loop()
loop.run_until_complete(execute())
elif sys.version_info >= (3, 11):
import uvloop # type: ignore Ignoring because on windows this will cause an error

with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: # type: ignore Ignoring because minor versions this will throw an error
runner.run(execute())
else:
import uvloop # type: ignore Ignoring because on windows this will cause an error

uvloop.install()
asyncio.run(execute())

_run_workflow_async()

self.register_signal_handlers(progress_reporter)

outputs = asyncio.run(
build_index(
default_config,
run_id,
self.memprofile,
progress_reporter,
pipeline_emit
)
)

encountered_errors = any(
output.errors and len(output.errors) > 0 for output in outputs
)

progress_reporter.stop()
if encountered_errors:
progress_reporter.error(
"Errors occurred during the pipeline run, see logs for more details."
error(
"Errors occurred during the pipeline run, see logs for more details.", True
)
else:
progress_reporter.success("All workflows completed successfully.")
success("All workflows completed successfully.", True)

if self.cli:
sys.exit(1 if encountered_errors else 0)
sys.exit(1 if encountered_errors else 0)

@staticmethod
def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
@@ -245,42 +263,6 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
)

def _create_default_config(
self,
root: str,
config: Optional[str],
verbose: bool,
dryrun: bool,
reporter: ProgressReporter,
) -> PipelineConfig:
"""Overlay default values on an existing config or create a default config if none is provided."""
if config and not Path(config).exists():
msg = f"Configuration file {config} does not exist"
raise ValueError(msg)

if not Path(root).exists():
msg = f"Root directory {root} does not exist"
raise ValueError(msg)

parameters = self._read_config_parameters(root, config, reporter)
log.info(
"using default configuration: %s",
self.redact(parameters.model_dump()),
)

if verbose or dryrun:
reporter.info(
f"Using default configuration: {self.redact(parameters.model_dump())}"
)
result = create_pipeline_config(parameters, verbose)
if verbose or dryrun:
reporter.info(f"Final Config: {self.redact(result.model_dump())}")

if dryrun:
reporter.info("dry run complete, exiting...")
sys.exit(0)
return result

@staticmethod
def _enable_logging(root: str, run_id: str, verbose: bool) -> None:
"""Enable logging to file and console."""
@@ -297,21 +279,3 @@ def _enable_logging(root: str, run_id: str, verbose: bool) -> None:
level=logging.DEBUG if verbose else logging.INFO,
)

@staticmethod
def _get_progress_reporter(
progress_reporter: Optional[str] = None,
) -> ProgressReporter:
"""Enable progress reporting to console."""
_reporter = progress_reporter or "print"
if _reporter == "null":
return NullProgressReporter()
elif _reporter == "print":
return PrintProgressReporter("GraphRAG Indexer ")
elif _reporter == "rich":
return RichProgressReporter("GraphRAG Indexer ")
else:
msg = (
f"Unsupported progress reporter: {_reporter}. "
f"Supported reporters are null, print and rich"
)
raise ValueError(msg)
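
Taken together, the index.py changes replace the old hand-rolled pipeline driver (run_pipeline_with_config, the per-platform event-loop setup, and the local _create_default_config / _get_progress_reporter helpers) with graphrag's build_index API and load_progress_reporter. Below is a minimal usage sketch; the class name GraphRAGIndexer and its import path are assumptions (the class definition sits outside the hunks shown), while the constructor arguments and run() are the ones visible in this diff:

from graphrag_api.index import GraphRAGIndexer  # assumed name/path; not shown in this diff

indexer = GraphRAGIndexer(
    root="rag",              # project root holding the GraphRAG config and input
    verbose=False,
    nocache=False,           # True forces default_config.cache.type = CacheType.none
    reporter="rich",         # passed to load_progress_reporter; "rich" is the fallback
    skip_validations=False,  # new flag in this commit; False keeps validate_config_names
)
indexer.run()                # runs build_index and reports per-workflow success or errors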
