forked from UOFT-DSI-SRI-ResponsibleLLM-Hackathon/team1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdense_retriever.py
158 lines (131 loc) · 6.49 KB
/
dense_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import faiss
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import numpy as np
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK is an Intel OpenMP runtime switch. Setting
# it to 'True' presumably works around a "duplicate OpenMP runtime" abort when
# faiss and torch each load their own OpenMP library (common on macOS) —
# confirm. It only suppresses the runtime's check; it does not resolve the
# underlying library conflict.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
class DPRRetriever:
    """Dense passage retriever built on DPR encoders and a FAISS inner-product index.

    Documents are embedded once, at construction time, with the DPR context
    encoder; queries are embedded with the DPR question encoder. Both kinds of
    embedding are L2-normalised, so inner-product search over the FAISS index
    ranks documents by cosine similarity.
    """

    _CTX_MODEL = 'facebook/dpr-ctx_encoder-single-nq-base'
    _Q_MODEL = 'facebook/dpr-question_encoder-single-nq-base'

    def __init__(self, documents, top_k=5, device=None, max_length=1024, batch_size=32):
        """
        Initializes the DPRRetriever with a list of documents.

        Args:
            documents (list of str): The corpus of documents to search. Must be non-empty.
            top_k (int): The number of top relevant documents to retrieve. Must be >= 1.
                If it exceeds the corpus size, every document is returned.
            device (str, optional): Device to run the models on ('cpu' or 'cuda').
                Defaults to 'cuda' if available.
            max_length (int, optional): Maximum token length for documents and queries.
                Clamped to the encoders' ``max_position_embeddings``, because longer
                inputs cannot be encoded. ``None`` means "use that limit".
            batch_size (int, optional): Number of documents encoded per forward pass
                while building the index. Bounds peak memory on large corpora.

        Raises:
            ValueError: If ``documents`` is empty, or ``top_k`` / ``batch_size`` is < 1.
        """
        # Validate before loading any model weights so bad input fails fast.
        if not documents:
            raise ValueError("documents must be a non-empty list of strings")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.documents = documents
        self.top_k = top_k
        self.batch_size = batch_size
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the encoders first: their configs define the longest input they accept.
        self.ctx_encoder = DPRContextEncoder.from_pretrained(self._CTX_MODEL).to(self.device)
        self.q_encoder = DPRQuestionEncoder.from_pretrained(self._Q_MODEL).to(self.device)

        # Position embeddings only exist for max_position_embeddings tokens.
        # Truncating at a larger length lets the forward pass receive sequences
        # it cannot embed, which raises a size-mismatch error, so clamp here.
        position_limit = min(
            self.ctx_encoder.config.max_position_embeddings,
            self.q_encoder.config.max_position_embeddings,
        )
        self.max_length = position_limit if max_length is None else min(max_length, position_limit)

        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            self._CTX_MODEL,
            model_max_length=self.max_length
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            self._Q_MODEL,
            model_max_length=self.max_length
        )

        # Encode all documents and build the FAISS index.
        self._build_index()
        print(f"[DPRRetriever] Initialized with {len(self.documents)} documents. Using device: {self.device}")

    def _embed(self, texts, tokenizer, encoder):
        """Return the encoder's ``pooler_output`` for ``texts`` (not normalised).

        Texts are processed ``self.batch_size`` at a time. Each batch is padded
        and truncated to ``self.max_length``. The per-batch outputs are
        concatenated on ``self.device``.
        """
        chunks = []
        for start in range(0, len(texts), self.batch_size):
            batch = tokenizer(
                texts[start:start + self.batch_size],
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )
            with torch.no_grad():
                chunks.append(encoder(
                    input_ids=batch['input_ids'].to(self.device),
                    attention_mask=batch['attention_mask'].to(self.device)
                ).pooler_output)
        return torch.cat(chunks, dim=0)

    def encode(self):
        """Encode every document with the context encoder.

        Returns:
            torch.Tensor: Un-normalised pooled embeddings of shape
            (num_docs, hidden_size), located on ``self.device``.
        """
        return self._embed(self.documents, self.ctx_tokenizer, self.ctx_encoder)

    def _build_index(self):
        """
        Encodes all documents and builds the FAISS index in ``self.index``.
        """
        # L2-normalise so inner product == cosine similarity; FAISS expects
        # contiguous float32 rows.
        ctx_embeddings = torch.nn.functional.normalize(self.encode(), p=2, dim=1).cpu().numpy()
        ctx_embeddings = np.ascontiguousarray(ctx_embeddings, dtype=np.float32)
        self.index = faiss.IndexFlatIP(ctx_embeddings.shape[1])
        self.index.add(ctx_embeddings)

    def retrieve(self, query):
        """
        Retrieves the most relevant documents for the given query.

        Args:
            query (str): The user's input query. A list of strings is also
                accepted, but only the first query's documents are returned.

        Returns:
            tuple: ``(retrieved_docs, indices)``, where ``retrieved_docs`` is a
            list of at most ``top_k`` documents (best first) for the first
            query, and ``indices`` is the FAISS index array of positions into
            ``documents``, with one row per query.
        """
        queries = [query] if isinstance(query, str) else list(query)
        q_embedding = torch.nn.functional.normalize(
            self._embed(queries, self.q_tokenizer, self.q_encoder), p=2, dim=1
        ).cpu().numpy()
        q_embedding = np.ascontiguousarray(q_embedding, dtype=np.float32)

        # FAISS pads results with index -1 when k exceeds the number of stored
        # vectors, and documents[-1] would silently return the last document,
        # so never ask for more neighbours than exist.
        k = min(self.top_k, self.index.ntotal)
        _, indices = self.index.search(q_embedding, k)

        retrieved_docs = [self.documents[idx] for idx in indices[0]]
        print(f"[DPRRetriever] Retrieved top {len(retrieved_docs)} documents for the query.")
        return retrieved_docs, indices
# Demo corpus: five short hydration passages, listed twice (so the retriever
# sees exact duplicate documents).
db = [
    "Dehydration occurs when your body loses more fluids than it takes in. Symptoms include dry mouth, fatigue, dizziness, and decreased urine output.",
    "Hydration is essential for maintaining bodily functions. Common signs of adequate hydration include regular urination and moist skin.",
    "Severe dehydration can lead to serious complications such as heatstroke, kidney failure, and seizures.",
    "Mild dehydration can often be remedied by drinking water or electrolyte-rich beverages.",
    "Athletes are particularly susceptible to dehydration and should monitor their fluid intake closely during training and competition.",
] * 2

# Build the index over the demo corpus, asking for the three best matches.
retriever = DPRRetriever(documents=db, top_k=3)
user_query = "What are athletes susceptible to?"

# NOTE(review): encode() re-embeds the whole corpus (it was already embedded
# when the retriever was built), and `embeddings` is not used below.
embeddings = retriever.encode()
retrieved_docs, chunk_indices = retriever.retrieve(user_query)

# Positions of the retrieved chunks inside `db`.
print(chunk_indices.squeeze())

print("\nRetrieved Documents:")
for idx, doc in enumerate(retrieved_docs, start=1):
    print(f"{idx}. {doc}")