|
8 | 8 | from langchain.chat_models import ChatOpenAI
|
9 | 9 | from pydantic import BaseModel
|
10 | 10 |
|
11 |
| -from fastapi_async_langchain.responses import LLMChainStreamingResponse, RetrievalQAStreamingResponse |
| 11 | +from fastapi_async_langchain.responses import ( |
| 12 | + LLMChainStreamingResponse, |
| 13 | + RetrievalQAStreamingResponse, |
| 14 | +) |
12 | 15 |
|
13 | 16 | load_dotenv()
|
14 | 17 |
|
@@ -46,32 +49,38 @@ async def chat(
|
46 | 49 | chain, request.query, media_type="text/event-stream"
|
47 | 50 | )
|
48 | 51 |
|
| 52 | + |
49 | 53 | def retrieval_qa_chain():
|
50 | 54 | from langchain.chains import RetrievalQAWithSourcesChain
|
51 | 55 | from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
52 | 56 | from langchain.chains.qa_with_sources.stuff_prompt import PROMPT as QA_PROMPT
|
53 |
| - from langchain.vectorstores import FAISS |
54 | 57 | from langchain.embeddings import OpenAIEmbeddings
|
| 58 | + from langchain.vectorstores import FAISS |
55 | 59 |
|
56 | 60 | callback_manager = AsyncCallbackManager([])
|
57 |
| - vectorstore = FAISS.load_local(index_name="langchain-python", embeddings=OpenAIEmbeddings(), folder_path="demo/") |
| 61 | + vectorstore = FAISS.load_local( |
| 62 | + index_name="langchain-python", |
| 63 | + embeddings=OpenAIEmbeddings(), |
| 64 | + folder_path="demo/", |
| 65 | + ) |
58 | 66 | retriever = vectorstore.as_retriever()
|
59 |
| - streaming_llm = ChatOpenAI(streaming=True, callback_manager=callback_manager, verbose=True, temperature=0) |
60 |
| - doc_chain = load_qa_with_sources_chain(llm=streaming_llm, |
61 |
| - chain_type="stuff", |
62 |
| - prompt=QA_PROMPT) |
63 |
| - return RetrievalQAWithSourcesChain(combine_documents_chain=doc_chain, |
64 |
| - retriever=retriever, |
65 |
| - callback_manager=callback_manager, |
66 |
| - return_source_documents=True, |
67 |
| - verbose=True) |
| 67 | + streaming_llm = ChatOpenAI( |
| 68 | + streaming=True, callback_manager=callback_manager, verbose=True, temperature=0 |
| 69 | + ) |
| 70 | + doc_chain = load_qa_with_sources_chain( |
| 71 | + llm=streaming_llm, chain_type="stuff", prompt=QA_PROMPT |
| 72 | + ) |
| 73 | + return RetrievalQAWithSourcesChain( |
| 74 | + combine_documents_chain=doc_chain, |
| 75 | + retriever=retriever, |
| 76 | + callback_manager=callback_manager, |
| 77 | + return_source_documents=True, |
| 78 | + verbose=True, |
| 79 | + ) |
| 80 | + |
68 | 81 |
|
69 | 82 | @app.post("/retrieval-qa-with-sources")
|
70 |
| -async def retrieval_qa_with_sources( |
71 |
| - request: Request |
72 |
| -) -> RetrievalQAStreamingResponse: |
| 83 | +async def retrieval_qa_with_sources(request: Request) -> RetrievalQAStreamingResponse: |
73 | 84 | return RetrievalQAStreamingResponse(
|
74 |
| - chain=retrieval_qa_chain(), |
75 |
| - inputs=request.query, |
76 |
| - media_type="text/event-stream" |
| 85 | + chain=retrieval_qa_chain(), inputs=request.query, media_type="text/event-stream" |
77 | 86 | )
|
0 commit comments