Skip to content

Commit

Permalink
Revert to writing to a file
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Apr 29, 2024
1 parent 306439c commit 1cdb3b9
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions with_openai/assets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from filelock import FileLock
import os
import pickle
from typing import Dict, List, Any

from dagster import (
Expand All @@ -22,7 +24,7 @@
#from langchain_community.vectorstores import FAISS
from langchain.vectorstores.faiss import FAISS

from .constants import SUMMARY_TEMPLATE
from .constants import SEARCH_INDEX_FILE, SUMMARY_TEMPLATE
from .utils import get_github_docs

docs_partitions_def = StaticPartitionsDefinition(
Expand All @@ -45,14 +47,13 @@
io_manager_key = "s3_io_manager"


# io_manager_key="fs_io_manager"
@asset(compute_kind="GitHub", partitions_def=docs_partitions_def)
def source_docs(context: AssetExecutionContext):
    """Fetch the dagster-io/dagster docs pages for this asset's partition.

    The partition key selects which slice of the GitHub docs to download;
    the fetched documents are materialized as a list so the IO manager can
    store them for downstream assets.
    """
    fetched = get_github_docs("dagster-io", "dagster", context.partition_key)
    return [doc for doc in fetched]


@asset(compute_kind="OpenAI", partitions_def=docs_partitions_def)
def search_index(context: AssetExecutionContext, openai: OpenAIResource, source_docs: List[Any]):
def search_index(context: AssetExecutionContext, openai: OpenAIResource, source_docs):
source_chunks = []
splitter = CharacterTextSplitter(separator=" ", chunk_size=1024, chunk_overlap=0)
for source in source_docs:
Expand All @@ -65,44 +66,42 @@ def search_index(context: AssetExecutionContext, openai: OpenAIResource, source_
source_chunks, OpenAIEmbeddings(client=client.embeddings)
)

return search_index.serialize_to_bytes()
with FileLock(SEARCH_INDEX_FILE):
if os.path.getsize(SEARCH_INDEX_FILE) > 0:
with open(SEARCH_INDEX_FILE, "rb") as f:
serialized_search_index = pickle.load(f)
cached_search_index = FAISS.deserialize_from_bytes(
serialized_search_index, OpenAIEmbeddings()
)
search_index.merge_from(cached_search_index)

with open(SEARCH_INDEX_FILE, "wb") as f:
pickle.dump(search_index.serialize_to_bytes(), f)


class OpenAIConfig(Config):
    """Run-time configuration for the `completion` asset."""

    # OpenAI chat model name; passed straight through as `model=` to ChatOpenAI.
    model: str
    # User question; used both for the similarity search and the final prompt.
    question: str


@asset(compute_kind="OpenAI", deps=[search_index])
def completion(
    context: AssetExecutionContext,
    openai: OpenAIResource,
    config: OpenAIConfig,
):
    """Answer ``config.question`` using the on-disk FAISS index and an OpenAI chat model.

    Loads the serialized search index written by the ``search_index`` asset,
    retrieves the 4 most similar doc chunks for the question, stuffs them into
    the summary prompt, and logs the chat model's answer.
    """
    # Take the same lock the search_index asset holds while writing
    # (see its FileLock(SEARCH_INDEX_FILE) block), so a concurrent
    # materialization can never hand us a half-written index file.
    with FileLock(SEARCH_INDEX_FILE):
        with open(SEARCH_INDEX_FILE, "rb") as f:
            # NOTE(review): pickle is only safe here because this file is
            # produced by our own search_index asset, never untrusted input.
            serialized_search_index = pickle.load(f)
    search_index = FAISS.deserialize_from_bytes(serialized_search_index, OpenAIEmbeddings())
    with openai.get_client(context) as client:
        prompt = stuff_prompt.PROMPT
        model = ChatOpenAI(client=client.chat.completions, model=config.model, temperature=0)
        # Stuff the top-4 most similar chunks (with their sources) into the prompt.
        summaries = " ".join(
            [
                SUMMARY_TEMPLATE.format(content=doc.page_content, source=doc.metadata["source"])
                for doc in search_index.similarity_search(config.question, k=4)
            ]
        )
        context.log.info(summaries)
        output_parser = StrOutputParser()
        chain = prompt | model | output_parser
        context.log.info(chain.invoke({"summaries": summaries, "question": config.question}))
Expand Down

0 comments on commit 1cdb3b9

Please sign in to comment.