Skip to content

fix(builder): multiple fixes to the KAG core code, improve stability #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions kag/builder/component/mapping/spg_type_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import defaultdict
from typing import Dict, List, Callable

import math
import pandas

from knext.schema.client import BASIC_TYPES
Expand Down Expand Up @@ -126,6 +127,10 @@ def assemble_sub_graph(self, properties: Dict[str, str]):
prop = self.spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
# If property key exists but value is empty (NaN), skip
if not prop_value or (type(prop_value) is float and math.isnan(prop_value)):
continue

prop_value_list = prop_value.split(",")
for o_id in prop_value_list:
if prop_name in self.link_funcs:
Expand Down
11 changes: 9 additions & 2 deletions kag/builder/prompt/default/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,15 @@ def parse_response(self, response: str, **kwargs):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
entities = rsp["named_entities"]
# - case1: {'named_entities': [...]}
# - case2: {'entities': [...]}
if isinstance(rsp, dict):
if "named_entities" in rsp:
entities = rsp["named_entities"]
elif "entities" in rsp:
entities = rsp["entities"]
else:
entities = rsp
else:
entities = rsp

Expand Down
24 changes: 22 additions & 2 deletions kag/builder/prompt/default/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,36 @@ def parse_response(self, response: str, **kwargs):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
standardized_entity = rsp["named_entities"]
# - case1: {'named_entities': [...]}
# - case2: {'entities': [...]}
if isinstance(rsp, dict):
if "named_entities" in rsp:
standardized_entity = rsp["named_entities"]
elif "entities" in rsp:
standardized_entity = rsp["entities"]
else:
standardized_entity = rsp
else:
standardized_entity = rsp

# In some cases LLM returns a JSON string
if type(standardized_entity) == str:
standardized_entity = json.loads(standardized_entity)

entities_with_offical_name = set()
merged = []
entities = kwargs.get("named_entities", [])
for entity in standardized_entity:
# LLM could also build result as a list of string
if type(entity) == str:
entity = json.loads(entity)
merged.append(entity)
entities_with_offical_name.add(entity["name"])

# Entities should be a list of dict, but LLM could return a dict {'entities': [...]}
if type(entities) != list:
entities = entities.get('entities', list())

# in case llm ignores some entities
for entity in entities:
if entity["name"] not in entities_with_offical_name:
Expand Down
18 changes: 16 additions & 2 deletions kag/common/registry/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,26 @@ def can_construct_from_config(type_: Type) -> bool:
def remove_optional(annotation: type) -> type:
"""
Remove Optional[X](alias of Union[T, None]) annotations by filtering out NoneType from Union[X, NoneType].
- Also supports both traditional Optional[X] syntax and Python 3.10+ pipe syntax (X | None)
"""
origin = get_origin(annotation)
args = get_args(annotation)

if origin == Union:
return Union[tuple([arg for arg in args if arg != type(None)])] # noqa
if origin == Union or origin is types.UnionType: # Support both Union and pipe syntax
filtered_args = tuple([arg for arg in args if arg != type(None)]) # noqa

# If only one type left after filtering None, return directly
if len(filtered_args) == 1:
return filtered_args[0]
# Otherwise return the Union of the remaining types
elif len(filtered_args) > 1:
if origin is types.UnionType: # For pipe syntax
return types.UnionType[filtered_args] # type: ignore
else: # For traditional Union
return Union[filtered_args]
# If all types were None, just return the original annotation
else:
return annotation
else:
return annotation

Expand Down
18 changes: 16 additions & 2 deletions kag/common/vectorize_model/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,18 @@ def vectorize(
Returns:
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
"""
# Some models require a list of strings as input rather than a single string, otherwise the model will generate
# vector for each character in the input string. Convert single string to list of strings to unify the input format.
if isinstance(texts, str):
texts = [texts]

results = self.client.embeddings.create(
input=texts, model=self.model, timeout=self.timeout
)
results = [item.embedding for item in results.data]
if isinstance(texts, str):

# If the input is a single string or a list with only one element, return the first element of the results list.
if isinstance(texts, str) or len(texts) == 1:
assert len(results) == 1
return results[0]
else:
Expand Down Expand Up @@ -125,11 +132,18 @@ def vectorize(
Returns:
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
"""
# Some models require a list of strings as input rather than a single string, otherwise the model will generate
# vector for each character in the input string. Convert single string to list of strings to unify the input format.
if isinstance(texts, str):
texts = [texts]

results = self.client.embeddings.create(
input=texts, model=self.model, timeout=self.timeout
)
results = [item.embedding for item in results.data]
if isinstance(texts, str):

# If the input is a single string or a list with only one element, return the first element of the results list.
if isinstance(texts, str) or len(texts) == 1:
assert len(results) == 1
return results[0]
else:
Expand Down
3 changes: 2 additions & 1 deletion kag/solver/execute/default_lf_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def _execute_chunk_answer(
all_related_entities = list(set(all_related_entities))
sub_query = self._generate_sub_query_with_history_qa(history, lf.query)
doc_retrieved = self.chunk_retriever.recall_docs(
queries=[query, sub_query],
# sub_query could be the same as original query
queries=[query, sub_query] if sub_query != query else [query],
retrieved_spo=all_related_entities,
kwargs=self.params,
)
Expand Down
6 changes: 6 additions & 0 deletions kag/solver/logic/core_modules/common/text_sim_by_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def sentence_encode(self, sentences, is_cached=False):
need_call_emb_text.append(text)
if len(need_call_emb_text) > 0:
emb_res = self.vectorize_model.vectorize(need_call_emb_text)

# If need_call_emb_text has only one element and emb_res is not of type list[list], convert emb_res to a list of lists
if len(need_call_emb_text) == 1:
if emb_res and type(emb_res[0]) != list:
emb_res = [emb_res]

for text, text_emb in zip(need_call_emb_text, emb_res):
tmp_map[text] = text_emb
if is_cached:
Expand Down
11 changes: 9 additions & 2 deletions kag/solver/prompt/default/question_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,15 @@ def parse_response(self, response: str, **kwargs):
rsp = json.loads(rsp)
if isinstance(rsp, dict) and "output" in rsp:
rsp = rsp["output"]
if isinstance(rsp, dict) and "named_entities" in rsp:
entities = rsp["named_entities"]
# - case1: {'named_entities': [...]}
# - case2: {'entities': [...]}
if isinstance(rsp, dict):
if "named_entities" in rsp:
entities = rsp["named_entities"]
elif "entities" in rsp:
entities = rsp["entities"]
else:
entities = rsp
else:
entities = rsp

Expand Down
Loading