# ------------------------------------------------------------------------
# Class ChatMoodle
#
# Copyright 2024 Pimenko <[email protected]><pimenko.com>
# Author Jordan Kesraoui
# License https://www.gnu.org/copyleft/gpl.html GNU GPL v3 or later
# ------------------------------------------------------------------------
from langchain.schema import AIMessage, SystemMessage, HumanMessage
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.vectorstores import FAISS
import tiktoken
import os
from dotenv import load_dotenv

load_dotenv()
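
# A minimal sketch of the expected .env file (an assumption based on this
# module: LLM_PROVIDER is mentioned in the provider check below, and
# ChatOpenAI reads OPENAI_API_KEY from the environment by default):
#   LLM_PROVIDER=openai
#   OPENAI_API_KEY=sk-...
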
class ChatMoodle:
    def __init__(self, token: str, llm_provider: str, model_name: str, model_cache: str, max_tokens: int, doc_language: str, courseid: int, instruction: str, history: list):
        """ Initialize the ChatMoodle class.

        Args:
            token (str): The API token
            llm_provider (str): The LLM provider ("openai", "ollama", or "null" for debug)
            model_name (str): The model name
            model_cache (str): The cache directory for embedding models
            max_tokens (int): The maximum number of prompt tokens
            doc_language (str): The document language
            courseid (int): The Moodle course id
            instruction (str): The system instruction
            history (list): The history of messages
        """
        self.token = token
        print("Using LLM provider: " + llm_provider)
        self.llm_provider = llm_provider
        print("Using model: " + model_name)
        self.model_name = model_name
        print("Model cache directory: " + model_cache)
        self.model_cache = model_cache
        print("Max prompt tokens: " + str(max_tokens))
        self.max_tokens = max_tokens
        print("Using language: " + doc_language)
        self.doc_language = doc_language
        print("On course ID: " + str(courseid))
        self.courseid = str(courseid)
        # Initialize the chat model.
        if llm_provider == "openai":
            self.chat = ChatOpenAI(temperature=0, model_name=model_name)
        elif llm_provider == "ollama":
            # Local models served via Ollama.
            self.chat = ChatOllama(model=model_name)
        elif llm_provider == "null":
            # Debug mode: call_chat() returns the retrieved context
            # instead of calling a model (see call_chat below).
            self.chat = None
        else:
            raise ValueError("LLM provider not recognized. Check the LLM_PROVIDER environment variable.")
        print("Using local FAISS.")
        self.vector_store_dir = "vector_stores/course_" + self.courseid
        self.vector_store = FAISS.load_local(
            self.vector_store_dir,
            HuggingFaceInstructEmbeddings(
                cache_folder=self.model_cache,
                model_name="sentence-transformers/all-MiniLM-L6-v2"))
        self.history = history
        print("Using instruction: " + instruction)
        self.instruction = instruction

    def provide_context_for_query(self, query: str, smart_search: bool = False):
        """ Provide context for a query.

        Args:
            query (str): The message query
            smart_search (bool): Whether to rewrite the query into search keywords first
        Returns:
            str: The retrieved context
        """
        if smart_search:
            system = """
            You are an AI that provides assistance in database search.
            Please translate the user's query to a list of search keywords
            that will be helpful in retrieving documents from a database
            based on similarity.
            The language of the keywords should match the language of the documents:
            """ + self.doc_language + """
            Answer with a list of keywords.
            """
            query = self.chat(
                [SystemMessage(content=system),
                 HumanMessage(content=query)]
            ).content
        docs = self.vector_store.similarity_search(query)
        context = "\n---\n".join(doc.page_content for doc in docs)
        return context

    # Define functions for memory management
    def purge_memory(self, messages: list):
        """ Purge memory to save tokens.

        Args:
            messages (list): The list of messages
        Returns:
            int: The token count
        """
        token_count = self.token_counter(messages)
        # Keep the system message (index 0) and drop the oldest history
        # messages until the prompt fits within the token budget.
        while token_count > self.max_tokens and len(messages) > 1:
            # Print purged message for testing purposes
            # print("Purged the following message:\n" + str(messages[1]))
            messages.pop(1)
            token_count = self.token_counter(messages)
        return token_count

    # NOTE: the prompt token count does not exactly match the OpenAI count,
    # since per-message framing tokens are not included.
    def token_counter(self, messages: list):
        """ Count tokens.

        Args:
            messages (list): The list of messages
        Returns:
            int: The token count
        """
        if self.model_name == "gpt-4":
            encoding = tiktoken.encoding_for_model("gpt-4")
        else:
            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        concatenated_content = ''.join(message.content for message in messages)
        return len(encoding.encode(concatenated_content))

    def call_chat(self, query: str):
        """ Call chat.

        Args:
            query (str): The message query
        Returns:
            str: The results content
        """
        # Search the vector store for relevant documents.
        context = self.provide_context_for_query(query)
        # Combine the instruction and the context into the system message.
        system_instruction = self.instruction + context
        # Convert the message history to a list of message objects:
        # even indices are user messages, odd indices are assistant replies.
        print("History: " + str(self.history))
        messages_history = []
        for i, message in enumerate(self.history):
            if i % 2 == 0:
                messages_history.append(HumanMessage(content=message))
            else:
                messages_history.append(AIMessage(content=message))
        print("Messages history: " + str(messages_history))
        # Build the message list: system instruction, history, then the new query.
        messages = [SystemMessage(content=system_instruction)]
        messages.extend(messages_history)
        messages.append(HumanMessage(content=query))
        # Purge memory to save tokens.
        # The current implementation is not ideal: Gradio keeps the entire
        # history in memory, so the messages are re-purged on every call once
        # the token count exceeds max_tokens.
        token_count = self.purge_memory(messages)
        if self.llm_provider != 'null':
            results = self.chat(messages)
            result_tokens = self.token_counter([results])
            print(f"Prompt tokens: {token_count}")
            print(f"Completion tokens: {result_tokens}")
            total_tokens = token_count + result_tokens
            print(f"Total tokens: {total_tokens}")
            results_content = results.content
        else:
            # Debug mode: return the retrieved context instead of a completion.
            results_content = context
        return results_content
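

# ------------------------------------------------------------------------
# Minimal usage sketch (not part of the class). All values below are
# illustrative assumptions: a real caller would pass its own token, course
# id, instruction, and history, and the FAISS index for the course must
# already exist under vector_stores/course_<courseid>.
# ------------------------------------------------------------------------
if __name__ == "__main__":
    chat = ChatMoodle(
        token="example-token",              # hypothetical API token
        llm_provider="openai",
        model_name="gpt-3.5-turbo",
        model_cache="model_cache",          # hypothetical cache directory
        max_tokens=3000,                    # hypothetical prompt budget
        doc_language="English",
        courseid=1,                         # hypothetical course id
        instruction="Answer using only the following course material:\n",
        history=[],                         # alternating user/assistant messages
    )
    print(chat.call_chat("What is this course about?"))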