Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions evals/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
),
]

# Evaluation data: questions about the NIKE 2023 10-K paired with the
# ground-truth answers the eval graders compare against.
_QA_PAIRS = [
    (
        "From the NIKE 2023 10k, Which operating segment contributed least to total Nike brand revenue in fiscal 2023?",
        "Global Brand Divisions",
    ),
    (
        "From the NIKE 2023 10k, Which operating segment contributed most to total Nike brand revenue in fiscal 2023?",
        "North America",
    ),
    (
        "From the NIKE 2023 10k, Based on the North America revenue table for fiscal 2023, which segment contributed least to North America revenue?",
        "Equipment",
    ),
    (
        "From the NIKE 2023 10k, Based on the North America revenue table for fiscal 2023, which segment contributed most to North America revenue?",
        "Footwear",
    ),
]
questions = [question for question, _ in _QA_PAIRS]
target_answers = [answer for _, answer in _QA_PAIRS]
dataset = [{"question": question, "target": answer} for question, answer in _QA_PAIRS]


def run_evals(
trace_id: str,
Expand Down
8 changes: 6 additions & 2 deletions ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from dotenv import load_dotenv
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Redis
from langchain_community.embeddings import HuggingFaceEmbeddings

from rag.config import EMBED_MODEL, REDIS_URL, INDEX_NAME, INDEX_SCHEMA
from rag.config import REDIS_URL, INDEX_NAME, INDEX_SCHEMA, EMBED_MODEL

load_dotenv()

Expand Down Expand Up @@ -44,3 +44,7 @@ def ingest_documents():
redis_url=REDIS_URL,
)
print("Finished.")


# Allow (re)building the vector index by running this file directly:
#   python ingest.py
if __name__ == "__main__":
    ingest_documents()
4,219 changes: 2,195 additions & 2,024 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ sse-starlette = "^1.8.2"
sentence-transformers = "^2.2.2"
parea-ai = "^0.2.25"
unstructured = {extras = ["pdf"], version = "^0.11.6"}
kay = "^0.1.2"

[tool.poetry.group.dev.dependencies]
langchain-cli = "^0.0.20"
Expand Down
127 changes: 22 additions & 105 deletions rag/chain.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,40 @@
import os

from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.vectorstores import Redis
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from parea import Parea
from parea import trace_insert, trace
from parea.evals.general import answer_matches_target_llm_grader_factory
from parea.evals.rag import percent_target_supported_by_context_factory
from parea.utils.trace_integrations.langchain import PareaAILangchainTracer

from evals.eval import run_evals
from rag.config import INDEX_NAME, INDEX_SCHEMA, REDIS_URL, EMBED_MODEL
from evals.eval import dataset
from rag.set_up import DocumentationChain, p

load_dotenv()

# Need to instantiate Parea for tracing and evals
p = Parea(api_key=os.getenv("PAREA_API_KEY"))

# Init Tracer which will send logs to Parea AI
parea_tracer = PareaAILangchainTracer()

# Init Embeddings
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

# Connect to preloaded vectorstore ( After you run ingest.py )
vectorstore = Redis.from_existing_index(
embedding=embedder, index_name=INDEX_NAME, schema=INDEX_SCHEMA, redis_url=REDIS_URL
)

# Init retriever
retriever = vectorstore.as_retriever(search_type="mmr")

# Init global variable to store context from retriever
context_from_retriever = ""


def _format_docs(docs) -> str:
"""
Format the docs retrieved from the retriever
:param docs: list of documents retrieved from the retriever
:return: str of concatenated documents
"""
global context_from_retriever
formatted_context = "\n\n".join(doc.page_content for doc in docs)
context_from_retriever = formatted_context
return formatted_context


# A sentinel value to indicate whether to add the source to the prompt
# When running evals such as exact match excluding the source could help the LLM judge the answer better
ADD_SOURCE = False
ADD_SOURCE_TEXT = """Include the 'source' and 'start_index from the metadata included in the context you used to
answer the question"""

# Define our prompt template
template = """
Use the following pieces of context from Nike's financial 10k filings
dataset to answer the question. Do not make up an answer if there is no
context provided to help answer it. {ADD_SOURCE_TEXT}

Context:
---------
{context}
dc = DocumentationChain()

---------
Question: {question}
---------

Answer:
"""

prompt = ChatPromptTemplate.from_template(template).partial(
ADD_SOURCE_TEXT=ADD_SOURCE_TEXT if ADD_SOURCE else ""
@trace(
eval_funcs=[
answer_matches_target_llm_grader_factory(),
percent_target_supported_by_context_factory(),
]
)


# RAG Chain
# gpt-3.5-turbo-16k chosen for the larger context window needed by 10-K chunks.
model = ChatOpenAI(model_name="gpt-3.5-turbo-16k")
# The retriever output is flattened by _format_docs (which also records the
# context for evals); the raw question passes through unchanged. Both feed
# the prompt, then the model, then a plain-string parser.
chain = (
    RunnableParallel(
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    )
    | prompt
    | model
    | StrOutputParser()
)


def invoke(question: str) -> tuple[str, str]:
    """
    Run the RAG chain on a question with Parea tracing enabled.

    :param question: the user question to answer
    :return: tuple of (response text, trace id as a string)
    """
    tracing_config = {"callbacks": [parea_tracer]}
    answer = chain.invoke(question, config=tracing_config)
    return answer, str(parea_tracer.get_parent_trace_id())


def run_chain(question: str, target: str, run_eval: bool):
def run_chain(question: str) -> str:
"""
Run the chain with the question and target answer and optionally run evals
:param question: question to ask
:param target: target answer
:param run_eval: whether to run evals

:return: None
:return: str
"""
response, trace_id = invoke(question)
print("Question: ", question, "\n")
print("Response: ", response, "\n")
output = dc.get_chain().invoke(
{"question": question},
config={"callbacks": [parea_tracer]},
)
trace_insert({"inputs": {"context": dc.get_context()}})
return output


if run_eval:
print("Evals started in thread: \n")
run_evals(
trace_id=trace_id,
question=question,
context=context_from_retriever,
response=response,
target_answer=target,
)
# Run the RAG chain as a Parea experiment over the local eval dataset.
if __name__ == "__main__":
    p.experiment(data=dataset, func=run_chain).run()
25 changes: 25 additions & 0 deletions rag/dataset_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from parea import trace
from parea.evals.general import answer_matches_target_llm_grader_factory as llm_grader
from parea.evals.rag import (
percent_target_supported_by_context_factory as context_quality,
)
from parea.schemas import Completion

from rag.set_up import p


@trace(eval_funcs=[llm_grader(), context_quality()])
def run_dataset_chain(question: str, context: str) -> str:
    """
    Answer a question via the deployed Parea prompt using pre-fetched context.

    :param question: question to ask
    :param context: retrieval context passed verbatim to the prompt
    :return: the completion's text content
    """
    request = Completion(
        deployment_id="p-T8VBzfZnwlVOUeD0aOwdu",
        llm_inputs={"context": context, "question": question},
    )
    return p.completion(request).content


# Run the prompt-only chain as a Parea experiment over the hosted "RAG2" dataset.
if __name__ == "__main__":
    p.experiment(
        data="RAG2",
        func=run_dataset_chain,
    ).run(name="rag_original_chain")
61 changes: 61 additions & 0 deletions rag/set_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
from operator import itemgetter

from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_community.retrievers import KayAiRetriever
from langchain_core.output_parsers import StrOutputParser
from parea import Parea
from parea.schemas.models import UseDeployedPrompt

load_dotenv()

# Need to instantiate Parea for tracing and evals.
# NOTE(review): assumes PAREA_API_KEY is set in the environment / .env —
# otherwise Parea receives api_key=None; confirm how the SDK handles that.
p = Parea(api_key=os.getenv("PAREA_API_KEY"))


class DocumentRetriever:
    """Builds the retriever used by the RAG chain.

    The earlier local Redis vector-store retriever (populated by ingest.py)
    was replaced by Kay.ai's hosted retriever over company 10-K filings.
    """

    def __init__(self):
        # Three contexts per query against the hosted "company" 10-K dataset.
        self.retriever = KayAiRetriever.create(
            dataset_id="company", data_types=["10-K"], num_contexts=3
        )

    def get_retriever(self):
        """Return the configured retriever instance."""
        return self.retriever


class DocumentationChain:
    """RAG chain: retriever -> deployed Parea prompt -> ChatOpenAI -> str.

    The prompt template and model name are fetched from a deployed Parea
    prompt at construction time.
    """

    # Most recently retrieved context; empty until the chain has run once.
    # Fix: previously this attribute was only created inside _format_docs,
    # so calling get_context() before the first invocation raised
    # AttributeError.
    context: str = ""

    def __init__(self):
        retriever = DocumentRetriever().get_retriever()
        # Pull the deployed prompt (template text + model name) from Parea.
        fetched = p.get_prompt(
            UseDeployedPrompt(deployment_id="p-T8VBzfZnwlVOUeD0aOwdu")
        )
        template = fetched.prompt.raw_messages[0]["content"]
        prompt = ChatPromptTemplate.from_template(template)
        model = ChatOpenAI(model_name=fetched.model)
        response_generator = prompt | model | StrOutputParser()
        # The question feeds both the retriever (to build "context") and the
        # prompt directly (as "question").
        self.chain = {
            "context": itemgetter("question") | retriever | self._format_docs,
            "question": itemgetter("question"),
        } | response_generator

    def get_context(self) -> str:
        """Helper to get the context from a retrieval chain, so we can use it for evaluation metrics."""
        return self.context

    def _format_docs(self, docs) -> str:
        """Join retrieved docs with blank lines and remember the result for evals."""
        context = "\n\n".join(doc.page_content for doc in docs)
        self.context = context
        return context

    def get_chain(self):
        """Return the composed runnable chain."""
        return self.chain