Skip to content

Commit

Permalink
0.3.2版本graphrag适配;
Browse files (browse the repository at this point in the history)
对部分代码逻辑做了优化,去除了一些冗余代码。
  • Loading branch information
张建平 committed Nov 27, 2024
1 parent cf9f65b commit 5a6786c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ search_runner.run_local_search(query="What are the top themes in this story?", s
search_runner.run_global_search(query="What are the top themes in this story?", streaming=False)

# 对于输出的结果可能带有一些特殊字符,可以采用以下函数去除特殊字符或自行处理。
search_runner.remove_sources(search_runner.run_local_search(query="What are the top themes in this story?"))
search_runner.remove_sources(search_runner.run_local_search(query="What are the top themes in this story?")[0])
```

### 报告问题
Expand Down
17 changes: 13 additions & 4 deletions graphrag_api/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import pandas as pd

from graphrag.config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from graphrag.index.progress import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.input.loaders.dfs import (
store_entity_semantic_embeddings,
)
from graphrag.query.structured_search.base import SearchResult
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.lancedb import LanceDBVectorStore

Expand Down Expand Up @@ -60,8 +62,10 @@ def __init__(
@validate_call(config={"arbitrary_types_allowed": True})
async def search(search_agent, query):
"""非流式搜索"""
result = await search_agent.asearch(query=query)
return result.response
result: SearchResult = await search_agent.asearch(query=query)
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore
return response, context_data

@staticmethod
@validate_call(config={"arbitrary_types_allowed": True})
Expand Down Expand Up @@ -217,6 +221,10 @@ def __get__local_agent(self):
entities = read_indexer_entities(
final_nodes, final_entities, self.community_level
)
base_dir = Path(str(root_dir)) / config.storage.base_dir
resolved_base_dir = resolve_timestamp_path(base_dir)
lancedb_dir = resolved_base_dir / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = self.__get_embedding_description_store(
entities=entities,
vector_store_type=vector_store_type,
Expand Down Expand Up @@ -256,12 +264,13 @@ def _configure_paths_and_settings(
root_dir: str | None,
config_dir: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
config = self._create_graphrag_config(root_dir, config_dir)
if data_dir is None and root_dir is None:
msg = "Either data_dir or root_dir must be provided."
raise ValueError(msg)
if data_dir is None:
data_dir = self._infer_data_dir(cast(str, root_dir))
config = self._create_graphrag_config(root_dir, config_dir)
base_dir = Path(str(root_dir)) / config.storage.base_dir
data_dir = str(resolve_timestamp_path(base_dir))
return data_dir, root_dir, config

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
graphrag==0.3.1
graphrag==0.3.2
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name="graphrag_api",
version="0.3.1",
version="0.3.2",
packages=find_packages(exclude=["tests"]),
install_requires=requirements,
author="nightzjp",
Expand Down

0 comments on commit 5a6786c

Please sign in to comment.