Adapt to graphrag 0.3.3;
Optimized parts of the code logic and removed some redundant code.
张建平 committed Nov 28, 2024
1 parent 5a6786c commit 1adce00
Showing 5 changed files with 122 additions and 98 deletions.
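The core of the adaptation is the config loading in index.py: the 0.3.2-era branching over load_config_from_file, resolve_config_path_with_root, and create_graphrag_config collapses into a single load_config call. Below is a minimal sketch of the new flow, with the load_config signature taken from this diff; the wrapper function name load_index_config is ours, not part of the library.

from pathlib import Path

from graphrag.config import CacheType, load_config  # consolidated 0.3.3 imports


def load_index_config(root_dir: str, config_filepath: str | None,
                      run_id: str, nocache: bool):
    """Sketch of the new loading flow; replaces the old FileNotFoundError fallback."""
    root = Path(root_dir).resolve()
    # 0.3.3: one entry point resolves the config file (or defaults) under root.
    config = load_config(root, config_filepath, run_id)
    if nocache:
        config.cache.type = CacheType.none
    return config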
50 changes: 20 additions & 30 deletions graphrag_api/index.py
@@ -14,14 +14,11 @@
from pathlib import Path
from typing import Optional

from graphrag.config import create_graphrag_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.config.enums import CacheType
from graphrag.config.config_file_loader import load_config_from_file, resolve_config_path_with_root
from graphrag.config import CacheType, enable_logging_with_config, load_config

from graphrag.index.validate_config import validate_config_names
from graphrag.index.api import build_index
from graphrag.index.progress import ProgressReporter

from graphrag.index.progress.load_progress_reporter import load_progress_reporter

from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@@ -46,27 +43,27 @@ def __init__(
root: str = ".",
verbose: bool = False,
resume: Optional[str] = None,
update_index_id: Optional[str] = None,
memprofile: bool = False,
nocache: bool = False,
reporter: Optional[str] = "",
config: Optional[str] = "",
config_filepath: Optional[str] = "",
emit: Optional[str] = "",
dryrun: bool = False,
init: bool = False,
overlay_defaults: bool = False,
skip_validations: bool = False
):
self.root = root
self.verbose = verbose
self.resume = resume
self.update_index_id = update_index_id
self.memprofile = memprofile
self.nocache = nocache
self.reporter = reporter
self.config = config
self.config_filepath = config_filepath
self.emit = emit
self.dryrun = dryrun
self.init = init
self.overlay_defaults = overlay_defaults
self.skip_validations = skip_validations
self.cli = False

@@ -146,37 +143,28 @@ def run(self):
self._initialize_project_at(self.root, progress_reporter)
sys.exit(0)

if self.overlay_defaults or self.config:
config_path = (
Path(self.root) / self.config if self.config else resolve_config_path_with_root(self.root)
)
default_config = load_config_from_file(config_path)
else:
try:
config_path = resolve_config_path_with_root(self.root)
default_config = load_config_from_file(config_path)
except FileNotFoundError:
default_config = create_graphrag_config(root_dir=self.root)
root = Path(self.root).resolve()
config = load_config(root, self.config_filepath, run_id)

if self.nocache:
default_config.cache.type = CacheType.none
config.cache.type = CacheType.none

enabled_logging, log_path = enable_logging_with_config(
default_config, run_id, self.verbose
config, self.verbose
)
if enabled_logging:
info(f"Logging enabled at {log_path}", True)
else:
info(
f"Logging not enabled for config {self.redact(default_config.model_dump())}",
f"Logging not enabled for config {self.redact(config.model_dump())}",
True,
)

if self.skip_validations:
validate_config_names(progress_reporter, default_config)
validate_config_names(progress_reporter, config)
info(f"Starting pipeline run for: {run_id}, {self.dryrun=}", self.verbose)
info(
f"Using default configuration: {self.redact(default_config.model_dump())}",
f"Using default configuration: {self.redact(config.model_dump())}",
self.verbose,
)

@@ -190,11 +178,13 @@

outputs = asyncio.run(
build_index(
default_config,
run_id,
self.memprofile,
progress_reporter,
pipeline_emit
config=config,
run_id=run_id,
is_resume_run=bool(self.resume),
is_update_run=bool(self.update_index_id),
memory_profile=self.memprofile,
progress_reporter=progress_reporter,
emit=pipeline_emit
)
)

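Callers of GraphRagIndexer change accordingly: config becomes config_filepath, overlay_defaults is removed, and update_index_id/skip_validations are new. A hypothetical caller might look like the following; all paths and values are placeholders.

from graphrag_api.index import GraphRagIndexer

# Placeholder paths/values; mirrors the updated __init__ signature above.
indexer = GraphRagIndexer(
    root="./ragtest",
    config_filepath="./ragtest/settings.yaml",  # was `config` in 0.3.2
    update_index_id=None,    # set to a prior run id to run an index update
    skip_validations=False,  # preflight-validation toggle added in this commit
)
indexer.run()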
145 changes: 85 additions & 60 deletions graphrag_api/search.py
@@ -4,19 +4,19 @@
Modified from graphrag\\query\\cli.py
"""
import asyncio
import re
import sys
import asyncio
from pathlib import Path
from typing import cast
from pydantic import validate_call
from collections.abc import AsyncGenerator

import pandas as pd
from pydantic import validate_call

from graphrag.config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from graphrag.config import GraphRagConfig, load_config, resolve_path
from graphrag.index.progress import PrintProgressReporter
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.utils.storage import _create_storage, _load_table_from_storage
from graphrag.model.entity import Entity
from graphrag.query.input.loaders.dfs import (
store_entity_semantic_embeddings,
@@ -148,20 +148,28 @@ def __get_embedding_description_store(

def __get__global_agent(self):
"""获取global搜索引擎"""
data_dir, root_dir, config = self._configure_paths_and_settings(
self.data_dir, self.root_dir, self.config_dir
root = Path(self.root_dir).resolve()
config = load_config(root)

if self.data_dir:
config.storage.base_dir = str(resolve_path(self.data_dir, root))

dataframe_dict = self._resolve_parquet_files(
root_dir=self.root_dir,
config=config,
parquet_list=[
"create_final_nodes.parquet",
"create_final_entities.parquet",
"create_final_community_reports.parquet",
],
optional_list=[]
)
data_path = Path(data_dir)

final_nodes: pd.DataFrame = pd.read_parquet(
data_path / "create_final_nodes.parquet"
)
final_entities: pd.DataFrame = pd.read_parquet(
data_path / "create_final_entities.parquet"
)
final_community_reports: pd.DataFrame = pd.read_parquet(
data_path / "create_final_community_reports.parquet"
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
final_community_reports: pd.DataFrame = dataframe_dict[
"create_final_community_reports"
]

reports = read_indexer_reports(
final_community_reports, final_nodes, self.community_level
@@ -188,28 +196,33 @@ def run_global_search(self, query, streaming=False):

def __get__local_agent(self):
"""获取local搜索引擎"""
data_dir, root_dir, config = self._configure_paths_and_settings(
self.data_dir, self.root_dir, self.config_dir
root = Path(self.root_dir).resolve()
config = load_config(root)

if self.data_dir:
config.storage.base_dir = str(resolve_path(self.data_dir, root))

dataframe_dict = self._resolve_parquet_files(
root_dir=self.root_dir,
config=config,
parquet_list=[
"create_final_nodes.parquet",
"create_final_community_reports.parquet",
"create_final_text_units.parquet",
"create_final_relationships.parquet",
"create_final_entities.parquet",
],
optional_list=["create_final_covariates.parquet"]
)
data_path = Path(data_dir)

final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
final_community_reports = pd.read_parquet(
data_path / "create_final_community_reports.parquet"
)
final_text_units = pd.read_parquet(
data_path / "create_final_text_units.parquet"
)
final_relationships = pd.read_parquet(
data_path / "create_final_relationships.parquet"
)
final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
final_covariates_path = data_path / "create_final_covariates.parquet"
final_covariates = (
pd.read_parquet(final_covariates_path)
if final_covariates_path.exists()
else None
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_community_reports: pd.DataFrame = dataframe_dict[
"create_final_community_reports"
]
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]

vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
@@ -221,20 +234,16 @@ def __get__local_agent(self):
entities = read_indexer_entities(
final_nodes, final_entities, self.community_level
)
base_dir = Path(str(root_dir)) / config.storage.base_dir
resolved_base_dir = resolve_timestamp_path(base_dir)
lancedb_dir = resolved_base_dir / "lancedb"

lancedb_dir = Path(config.storage.base_dir) / "lancedb"

vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = self.__get_embedding_description_store(
entities=entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)
covariates = (
read_indexer_covariates(final_covariates)
if final_covariates is not None
else []
)
covariates = read_indexer_covariates(final_covariates) if final_covariates is not None else []

return get_local_search_engine(
config,
@@ -258,21 +267,6 @@ def run_local_search(self, query, streaming=False):
)
return asyncio.run(self.search(search_agent=self.__local_agent, query=query))

def _configure_paths_and_settings(
self,
data_dir: str | None,
root_dir: str | None,
config_dir: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
config = self._create_graphrag_config(root_dir, config_dir)
if data_dir is None and root_dir is None:
msg = "Either data_dir or root_dir must be provided."
raise ValueError(msg)
if data_dir is None:
base_dir = Path(str(root_dir)) / config.storage.base_dir
data_dir = str(resolve_timestamp_path(base_dir))
return data_dir, root_dir, config

@staticmethod
def _infer_data_dir(root: str) -> str:
output = Path(root) / "output"
@@ -289,6 +283,37 @@ def _infer_data_dir(root: str) -> str:
msg = f"Could not infer data directory from root={root}"
raise ValueError(msg)

@staticmethod
def _resolve_parquet_files(
root_dir: str,
config: GraphRagConfig,
parquet_list: list[str],
optional_list: list[str],
) -> dict[str, pd.DataFrame]:
"""Read parquet files to a dataframe dict."""
dataframe_dict = {}
pipeline_config = create_pipeline_config(config)
storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
for parquet_file in parquet_list:
df_key = parquet_file.split(".")[0]
df_value = asyncio.run(
_load_table_from_storage(name=parquet_file, storage=storage_obj)
)
dataframe_dict[df_key] = df_value

# for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
for optional_file in optional_list:
file_exists = asyncio.run(storage_obj.has(optional_file))
df_key = optional_file.split(".")[0]
if file_exists:
df_value = asyncio.run(
_load_table_from_storage(name=optional_file, storage=storage_obj)
)
dataframe_dict[df_key] = df_value
else:
dataframe_dict[df_key] = None
return dataframe_dict

def _create_graphrag_config(
self,
root: str | None,
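The new _resolve_parquet_files routes all reads through graphrag's storage layer instead of raw pd.read_parquet, so timestamped output directories resolve via the pipeline config rather than manual resolve_timestamp_path handling. A condensed sketch of the same pattern, using only the calls that appear in this diff; "./ragtest" is a placeholder project root.

import asyncio
from pathlib import Path

from graphrag.config import load_config
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.utils.storage import _create_storage, _load_table_from_storage

root = "./ragtest"  # placeholder project root
config = load_config(Path(root).resolve())
storage = _create_storage(
    root_dir=root, config=create_pipeline_config(config).storage
)
# Load one output table through the storage abstraction instead of pandas I/O.
nodes = asyncio.run(
    _load_table_from_storage(name="create_final_nodes.parquet", storage=storage)
)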
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1 +1 @@
graphrag==0.3.2
graphrag==0.3.3
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@

setup(
name="graphrag_api",
version="0.3.2",
version="0.3.3",
packages=find_packages(exclude=["tests"]),
install_requires=requirements,
author="nightzjp",
21 changes: 15 additions & 6 deletions tests/index_test.py
@@ -62,14 +62,23 @@
action="store_true",
)
parser.add_argument(
"--overlay-defaults",
help="Overlay default configuration values on a provided configuration file (--config).",
"--skip-validations",
help="Skip any preflight validation. Useful when running no LLM steps.",
action="store_true",
)
parser.add_argument(
"--update-index",
help="Update a given index run id, leveraging previous outputs and applying new indexes.",
# Only required if config is not defined
required=False,
default=None,
type=str,
)
args = parser.parse_args()

if args.overlay_defaults and not args.config:
parser.error("--overlay-defaults requires --config")
if args.resume and args.update_index:
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

indexer = GraphRagIndexer(
root=args.root,
@@ -78,11 +87,11 @@
memprofile=args.memprofile or False,
nocache=args.nocache or False,
reporter=args.reporter,
config=args.config,
config_filepath=args.config,
emit=args.emit,
dryrun=args.dryrun or False,
init=args.init or False,
overlay_defaults=args.overlay_defaults or False,
skip_validations=args.skip_validations or False,
)

indexer.run()
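For completeness, an illustrative way to exercise the new flags; the run id and paths are placeholders, and --resume together with --update-index is rejected by the guard above.

import subprocess

# Placeholder run id/paths; --skip-validations combines freely, but never
# pass --resume together with --update-index.
subprocess.run(
    [
        "python", "tests/index_test.py",
        "--root", "./ragtest",
        "--update-index", "20241128-120000",
        "--skip-validations",
    ],
    check=True,
)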
