rag_implementation.py
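"""Minimal RAG (retrieval-augmented generation) pipeline.

Loads a text document, splits it into overlapping chunks, embeds the chunks
into a Chroma vector store, and answers the questions in questions.json with
a LangChain RetrievalQA chain backed by Google PaLM. Results are written to
answers.json.
"""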
from dotenv import load_dotenv
import os
import json
from typing import Any, Dict, List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import GooglePalmEmbeddings
from langchain_community.llms import GooglePalm

load_dotenv()

# Load the Google API key; both the PaLM LLM and the embeddings below read
# GOOGLE_API_KEY from the environment.
google_api_key = os.getenv("GOOGLE_API_KEY")
if google_api_key is None:
    raise ValueError(
        "GOOGLE_API_KEY not found in environment variables. "
        "Please set it in your .env file."
    )
os.environ["GOOGLE_API_KEY"] = google_api_key
def load_document(file_path: str) -> str:
    """Read the entire document into a single string."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()

def split_text(text: str) -> List[str]:
    # Split into ~1000-character chunks with 200 characters of overlap so
    # that context is not lost at chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    return text_splitter.split_text(text)

def create_vector_store(chunks: List[str]) -> Chroma:
    # Embed each chunk with Google PaLM embeddings and index the chunks in a
    # Chroma vector store.
    embeddings = GooglePalmEmbeddings()
    return Chroma.from_texts(chunks, embeddings)
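# Note: the store above is in-memory and rebuilt on every run. Chroma also
# supports on-disk persistence via the persist_directory argument of
# Chroma.from_texts (the directory name below is a hypothetical example):
#   Chroma.from_texts(chunks, embeddings, persist_directory="chroma_db")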
def setup_rag_pipeline(vector_store: Chroma) -> RetrievalQA:
    # The "stuff" chain type simply concatenates all retrieved chunks into
    # the prompt; as_retriever() uses similarity search with LangChain's
    # default number of retrieved documents.
    llm = GooglePalm()
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
        return_source_documents=True
    )
def generate_answer(qa_chain: RetrievalQA, question: str) -> Dict[str, Any]:
    # invoke() is the current LangChain entry point; calling the chain
    # directly is deprecated.
    result = qa_chain.invoke({"query": question})
    return {
        "question": question,
        "answer": result["result"],
        "contexts": [doc.page_content for doc in result["source_documents"]]
    }
def implement_rag() -> RetrievalQA:
    # Load and chunk the source document.
    doc_path = "docs/intro-to-llms-karpathy.txt"
    document = load_document(doc_path)
    chunks = split_text(document)

    # Build the vector store over the chunks.
    vector_store = create_vector_store(chunks)

    # Wire the retriever and the LLM into a RetrievalQA chain.
    qa_chain = setup_rag_pipeline(vector_store)
    return qa_chain

def generate_answers(questions: List[Dict[str, str]], qa_chain: RetrievalQA) -> List[Dict[str, Any]]:
    return [generate_answer(qa_chain, q["question"]) for q in questions]
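# questions.json is expected to hold a list of objects, each with a
# "question" key (the question text below is a hypothetical example):
#   [{"question": "What is a large language model?"}]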
if __name__ == "__main__":
    # Build the RAG pipeline.
    qa_chain = implement_rag()

    # Load questions from questions.json.
    with open('questions.json', 'r') as f:
        questions = json.load(f)

    # Generate answers.
    answers = generate_answers(questions, qa_chain)

    # Save answers to a file.
    with open('answers.json', 'w') as f:
        json.dump(answers, f, indent=2)

    print("Answers generated and saved to answers.json")
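# Each entry written to answers.json has the shape produced by
# generate_answer above (the values below are hypothetical placeholders):
#   {
#     "question": "...",
#     "answer": "...",
#     "contexts": ["retrieved chunk 1", "retrieved chunk 2"]
#   }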