Skip to content

Commit

Permalink
Merge pull request #9 from robertmcleod2/add-rag
Browse files Browse the repository at this point in the history
Add rag
  • Loading branch information
robertmcleod2 authored Oct 7, 2024
2 parents 652b93b + f98aebc commit a88131e
Show file tree
Hide file tree
Showing 12 changed files with 556 additions and 70 deletions.
11 changes: 6 additions & 5 deletions .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ jobs:
contents: read
# retrieve variables from environment secrets and variables
env:
# deplyoment variables
# deployment variables
AZURE_SUBSCRIPTION_ID: ${{ vars.AZURE_SUBSCRIPTION_ID }}
AZURE_TENANT_ID: ${{ vars.AZURE_TENANT_ID }}
AZURE_CLIENT_ID: ${{ vars.AZURE_CLIENT_ID }}
AZURE_CLIENT_CERT_NAME: ${{ vars.AZURE_CLIENT_CERT_NAME }}
AZURE_CLIENT_CERT: ${{ secrets.AZURE_CLIENT_CERT }}
# blank variable to store the certificate content
AZURE_CLIENT_CERTIFICATE_PATH: "blank"
AZURE_CLIENT_CERTIFICATE: ${{ secrets.AZURE_CLIENT_CERTIFICATE }}
AZURE_RESOURCE_GROUP: ${{ vars.AZURE_RESOURCE_GROUP }}
# App env variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand All @@ -38,11 +39,11 @@ jobs:
- name: run deploy.sh
shell: bash
run: |
echo "$AZURE_CLIENT_CERT" > ${{ vars.AZURE_CLIENT_CERT_NAME }}
echo "$AZURE_CLIENT_CERTIFICATE" > ${AZURE_CLIENT_CERTIFICATE_PATH}
bash deployment/deploy.sh
- name: clean up
if: always()
shell: bash
run: |
az logout
rm ${{ vars.AZURE_CLIENT_CERT_NAME }}
rm ${AZURE_CLIENT_CERTIFICATE_PATH}
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,6 @@ pip install -r requirements_dev.txt

```bash
streamlit run src/app.py
```
```

6. View the application in your browser at `http://localhost:8501`. The password is the one you set in the `.env` file.
2 changes: 1 addition & 1 deletion deployment/deploy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export $(egrep -v '^#' .env | xargs)
LOCATION=westeurope

# log in to Azure with service principal
az login --service-principal --username ${AZURE_CLIENT_ID} --password ${AZURE_CLIENT_CERT_NAME} --tenant ${AZURE_TENANT_ID}
az login --service-principal --username ${AZURE_CLIENT_ID} --password ${AZURE_CLIENT_CERTIFICATE_PATH} --tenant ${AZURE_TENANT_ID}
az account set --subscription ${AZURE_SUBSCRIPTION_ID}
az config set defaults.group=${AZURE_RESOURCE_GROUP} defaults.location=${LOCATION}

Expand Down
2 changes: 1 addition & 1 deletion deployment/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export $(egrep -v '^#' .env | xargs)
LOCATION=westeurope

# log in to Azure with service principal
az login --service-principal --username ${AZURE_CLIENT_ID} --password ${AZURE_CLIENT_CERT_NAME} --tenant ${AZURE_TENANT_ID}
az login --service-principal --username ${AZURE_CLIENT_ID} --password ${AZURE_CLIENT_CERTIFICATE_PATH} --tenant ${AZURE_TENANT_ID}
az account set --subscription ${AZURE_SUBSCRIPTION_ID}
az config set defaults.group=${AZURE_RESOURCE_GROUP} defaults.location=${LOCATION}

Expand Down
8 changes: 4 additions & 4 deletions basic_chatbot.ipynb → notebooks/basic_chatbot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@
"\n",
"store = {}\n",
"\n",
"\n",
"def get_session_history(session_id: str) -> BaseChatMessageHistory:\n",
" if session_id not in store:\n",
" store[session_id] = InMemoryChatMessageHistory()\n",
" return store[session_id]\n",
"\n",
"\n",
"trimmer = trim_messages(\n",
" max_tokens=1000,\n",
" strategy=\"last\",\n",
Expand All @@ -66,15 +68,13 @@
" The user will describe their energy usage and you will help them to detect anomalies.\\n\n",
" You will also help the user to identify the causes of the anomalies \\n\n",
" and suggest ways to fix them.\\n\n",
" \"\"\"\n",
" \"\"\",\n",
" ),\n",
" MessagesPlaceholder(variable_name=\"messages\"),\n",
" ]\n",
")\n",
"\n",
"chain = (\n",
" trimmer | prompt | model\n",
")\n",
"chain = trimmer | prompt | model\n",
"\n",
"with_message_history = RunnableWithMessageHistory(\n",
" chain,\n",
Expand Down
210 changes: 202 additions & 8 deletions exploration.ipynb → notebooks/exploration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,19 @@
"outputs": [],
"source": [
"response = with_message_history.invoke(\n",
" [HumanMessage(content=\"\"\"list the names of people who have said hello to you in our conversation so far.\\n\n",
" [\n",
" HumanMessage(\n",
" content=\"\"\"list the names of people who have said hello to you in our conversation so far.\\n\n",
"List the names following the below format:\\n\n",
"\n",
"[<Name 1>, <Name 2>, <Name 3>, ...]\n",
"\n",
"For example, if the names of people who have said hello to you in our conversation so far are Alice and Bob, you should list them as follows:\n",
" \n",
"[Alice, Bob] \n",
"\"\"\")],\n",
"\"\"\"\n",
" )\n",
" ],\n",
" config=config,\n",
")\n",
"\n",
Expand Down Expand Up @@ -291,11 +295,7 @@
" AIMessage(content=\"yes!\"),\n",
"]\n",
"\n",
"chain = (\n",
" RunnablePassthrough.assign(messages=itemgetter(\"messages\") | trimmer)\n",
" | prompt\n",
" | model\n",
")\n",
"chain = RunnablePassthrough.assign(messages=itemgetter(\"messages\") | trimmer) | prompt | model\n",
"\n",
"with_message_history = RunnableWithMessageHistory(\n",
" chain,\n",
Expand All @@ -313,7 +313,7 @@
" config=config,\n",
")\n",
"\n",
"response.content\n"
"response.content"
]
},
{
Expand All @@ -339,6 +339,200 @@
"):\n",
" print(r.content, end=\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chatbot with memory and RAG\n",
"\n",
"The chatbot uses the retriever to fetch the information most relevant to the query from the vector store, and it does so for every query. \n",
"\n",
"This makes each response slower, because the information has to be retrieved first, but the context we get is the most relevant to the query. \n",
"\n",
"This is probably not the most useful approach for our application: for now we will normally just want to retrieve the information related to energy anomaly detection and use it as context every time.\n",
"\n",
"We may want to give the chatbot agent the ability to decide whether or not to retrieve information from the vector store."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"from dotenv import load_dotenv\n",
"from langchain.chains import create_history_aware_retriever, create_retrieval_chain\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.globals import set_debug\n",
"from langchain_community.chat_message_histories import ChatMessageHistory\n",
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain_core.runnables.history import RunnableWithMessageHistory\n",
"from langchain_core.vectorstores import InMemoryVectorStore\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"from langchain_text_splitters import RecursiveJsonSplitter\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
"\n",
"### Construct retriever ###\n",
"with open(\"../src/example_customer_documents.json\") as f:\n",
" json_data = json.load(f)\n",
"\n",
"splitter = RecursiveJsonSplitter(max_chunk_size=300)\n",
"docs = splitter.create_documents(texts=[json_data])\n",
"\n",
"vectorstore = InMemoryVectorStore.from_documents(documents=docs, embedding=OpenAIEmbeddings())\n",
"retriever = vectorstore.as_retriever()\n",
"\n",
"\n",
"### Contextualize question ###\n",
"contextualize_q_system_prompt = \"\"\"Given a chat history and the latest user question \\\n",
"which might reference context in the chat history, formulate a standalone question \\\n",
"which can be understood without the chat history. \\\n",
"Include information relevant to energy anomaly detection in the question, if needed. \\\n",
"Do NOT answer the question, just reformulate it if needed and otherwise return it as is.\"\"\"\n",
"contextualize_q_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", contextualize_q_system_prompt),\n",
" MessagesPlaceholder(\"chat_history\"),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)\n",
"\n",
"\n",
"### Answer question ###\n",
"qa_system_prompt = \"\"\"You are an an energy usage anomaly detection assistant. \\\n",
"You are helping a user to detect anomalies in their energy usage. \\\n",
"The user will describe their energy usage and you will help them to detect anomalies. \\\n",
"You will also help the user to identify the causes of the anomalies \\\n",
"and suggest ways to fix them. \\\n",
"\n",
"Use the following pieces of retrieved context to answer the question if needed. \\\n",
"Follow up on previous parts of customer service chatbot and agent conversations that are not yet resolved. \\\n",
"\n",
"{context}\"\"\"\n",
"\n",
"qa_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", qa_system_prompt),\n",
" MessagesPlaceholder(\"chat_history\"),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)\n",
"\n",
"rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)\n",
"\n",
"\n",
"### Statefully manage chat history ###\n",
"chat_history = ChatMessageHistory()\n",
"\n",
"\n",
"conversational_rag_chain = RunnableWithMessageHistory(\n",
" rag_chain,\n",
" lambda session_id: chat_history,\n",
" input_messages_key=\"input\",\n",
" history_messages_key=\"chat_history\",\n",
" output_messages_key=\"answer\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.globals import set_debug\n",
"\n",
"set_debug(False)\n",
"\n",
"conversational_rag_chain.invoke({\"input\": \"I increased my thermostat temperature during the day\"}, {\"configurable\": {\"session_id\": \"unused\"}})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chatbot that is passed a fixed RAG context\n",
"\n",
"The simplest implementation of RAG: the retriever is run once with a fixed query, and the resulting context from the vector store is passed to the chatbot on every call. "
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"from dotenv import load_dotenv\n",
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain_core.runnables.history import RunnableWithMessageHistory\n",
"from langchain_core.vectorstores import InMemoryVectorStore\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"from langchain_text_splitters import RecursiveJsonSplitter\n",
"\n",
"load_dotenv()\n",
"\n",
"model = ChatOpenAI(model=\"gpt-4o-mini\")\n",
"\n",
"# local document retrieval\n",
"with open(\"../src/example_customer_documents.json\") as f:\n",
" json_data = json.load(f)\n",
"\n",
"splitter = RecursiveJsonSplitter(max_chunk_size=300)\n",
"docs = splitter.create_documents(texts=[json_data])\n",
"\n",
"vectorstore = InMemoryVectorStore.from_documents(documents=docs, embedding=OpenAIEmbeddings())\n",
"retriever = vectorstore.as_retriever()\n",
"\n",
"context = retriever.invoke(\"energy anomaly detection smart meter energy consumption\")\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"\"\"You are an an energy usage anomaly detection assistant.\\n\n",
" You are helping a user to detect anomalies in their energy usage.\\n\n",
" The user will describe their energy usage and you will help them to detect anomalies.\\n\n",
" You will also help the user to identify the causes of the anomalies \\n\n",
" and suggest ways to fix them.\\n\n",
"\n",
" Use the following pieces of retrieved context to answer the question if needed.\\n\n",
" {context}\n",
" \"\"\",\n",
" ),\n",
" MessagesPlaceholder(variable_name=\"chat_history\"),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | model\n",
"\n",
"### Statefully manage chat history ###\n",
"chat_history = ChatMessageHistory()\n",
"\n",
"chain_with_message_history = RunnableWithMessageHistory(\n",
" chain,\n",
" lambda session_id: chat_history,\n",
" input_messages_key=\"input\",\n",
" history_messages_key=\"chat_history\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain_with_message_history.invoke({\"context\": context, \"input\": \"I increased my thermostat temperature during the day\"}, {\"configurable\": {\"session_id\": \"unused\"}})"
]
}
],
"metadata": {
Expand Down
9 changes: 4 additions & 5 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import streamlit as st
from basic_chatbot import setup_chain
from chatbot_rag import ChatbotRAG as Chatbot
from dotenv import load_dotenv
from utils import check_password

Expand All @@ -10,8 +10,8 @@

st.title("Energy Usage Anomaly Detection Assistant")

if "chain" not in st.session_state:
st.session_state.chain = setup_chain()
if "chatbot" not in st.session_state:
st.session_state.chatbot = Chatbot()

if "messages" not in st.session_state:
st.session_state.messages = []
Expand All @@ -26,6 +26,5 @@
st.markdown(prompt)

with st.chat_message("assistant"):
stream = st.session_state.chain.stream({"input": prompt}, {"configurable": {"session_id": "unused"}})
response = st.write_stream(stream)
response = st.session_state.chatbot.stream(prompt)
st.session_state.messages.append({"role": "assistant", "content": response})
Loading

0 comments on commit a88131e

Please sign in to comment.