prompter.py
import requests
import json
import yaml
import asyncio
import os
import time  # used by generate_more to pace successive API calls
# Constants for API endpoints
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions" # OpenRouter chat completions endpoint
HF_URL = "https://api-inference.huggingface.co/v1/chat/completions" # HuggingFace chat completions endpoint (class default)
def worker(args):
"""
Worker function to handle a single prompt with its index.
This needs to be a top-level function for multiprocessing to work.
"""
instance, index, prompt = args
prompt_dict = [{"role": "user", "content": instance.make_prompt(prompt)}]
response = instance.generate(prompt_dicts=prompt_dict)
return index, response
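# A minimal sketch of how `worker` could be driven with multiprocessing.Pool
# (illustrative only; the configuration values are assumptions and this is not
# called anywhere in this module):
#
#     from multiprocessing import Pool
#     prompter = Prompter()
#     prompter._set_token("YOUR_API_TOKEN")           # hypothetical token
#     prompter._set_model("some-org/some-model")      # hypothetical model id
#     jobs = [(prompter, i, p) for i, p in enumerate(["Hello", "Goodbye"])]
#     with Pool(processes=2) as pool:
#         indexed_responses = pool.map(worker, jobs)  # list of (index, response)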
class Prompter:
"""
A class for making requests to LLM APIs
"""
def __init__(
self,
base_url=HF_URL, # Default base URL for API requests
):
"""
Initializes the Prompter instance with default or specified values.
Args:
base_url (str): The base URL of the API endpoint. Defaults to the HuggingFace endpoint (HF_URL).
"""
self.base_url = base_url # Base URL for API requests
self.token = None # API authorization token
self.model_name = None # Name of the model to use
self.generation_args = {} # Additional arguments for generation
self.logged = False # Reserved flag, possibly for logging activity (unused here)
self.prompt_template = None
# =============================================
def _set_base_url(
self,
base_url: str, # New base URL for API requests
):
"""
Sets the base URL for API requests.
Args:
base_url (str): The new base URL for the API.
"""
self.base_url = base_url
return
def _set_token(
self,
token: str, # API authorization token
):
"""
Sets the authorization token for API requests.
Args:
token (str): The API token.
"""
self.token = token
return
def _set_model(
self,
model_name: str, # Model identifier for the API
):
"""
Sets the model name to use for requests.
Args:
model_name (str): The model identifier (e.g., "gpt-3.5-turbo").
"""
self.model_name = model_name
return
def _set_generation_args(
self,
generation_args: dict = None, # Additional generation parameters
):
"""
Sets additional generation arguments to customize API behavior.
Args:
generation_args (dict): Dictionary of generation parameters (e.g., temperature, max_tokens).
"""
# Avoid a shared mutable default by falling back to a fresh dict
self.generation_args = generation_args if generation_args is not None else {}
return
def _update_generation_arg(
self,
key,
value,
):
"""
Update a key of the generation args
"""
self.generation_args[key] = value
return
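# Example of configuring generation parameters (a sketch; which keys the backend
# accepts depends on the API, so temperature/max_tokens are assumptions):
#
#     prompter._set_generation_args({"temperature": 0.7, "max_tokens": 512})
#     prompter._update_generation_arg("temperature", 0.2)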
# =============================================
def generate(
self,
prompt_dicts: list[dict], # List of message dictionaries defining the conversation
stream: bool = False, # Flag for streaming responses (not implemented in this method)
):
"""
Sends a request to the API to generate a response based on the given prompts.
Args:
prompt_dicts (list[dict]): A list of message dictionaries containing the prompts.
stream (bool): Whether to stream the response. Defaults to False.
Returns:
str: The content of the response message.
"""
# Make a POST request to the API
response = requests.post(
url=self.base_url,
headers={
"Authorization": f"Bearer {self.token}", # Authorization header with token
},
json={
"model": self.model_name, # Model name for the request
"messages": prompt_dicts, # Conversation history/messages
**self.generation_args # Additional generation arguments
}
)
# Parse the API response and return the content of the first choice
return response.json()["choices"][0]["message"]["content"]
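# Example call (a sketch; the message content is illustrative and the token and
# model are assumed to have been set with _set_token / _set_model):
#
#     reply = prompter.generate(
#         prompt_dicts=[{"role": "user", "content": "Summarize this text ..."}],
#     )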
async def async_generate(
self,
prompt_dicts,
index: int,
):
"""
Runs the blocking `generate` call in the default executor so it can be awaited,
and returns the response together with its original index.
"""
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(None, self.generate, prompt_dicts)
return index, response
async def async_generate_batch(
self,
prompts: list, # List of prompt strings
batch_size: int = 16,
max_retries: int = 0, # Maximum retries for failed prompts
error_callback=None, # Function to report errors to the Streamlit UI
):
"""
Generates responses for a batch of prompts with parallel requests and error handling.
Args:
prompts (list): A list of prompt strings.
batch_size (int): Number of prompts to process in parallel.
max_retries (int): Maximum number of retries for failed prompts.
error_callback (function): A callback function to log or display errors in the Streamlit app.
Returns:
list: A list of response strings for each prompt. Failed prompts are replaced with error messages.
"""
if batch_size is None:
batch_size = len(prompts)
indexed_prompts = [(i, p) for i, p in enumerate(prompts)]
results = [""] * len(prompts) # Placeholder for results
errors = {}
for i in range(0, len(indexed_prompts), batch_size):
batch = indexed_prompts[i:i + batch_size] # Get the next batch
# Create async tasks for the batch
tasks = [
self.async_generate(
prompt_dicts=[{"role": "user", "content": self.make_prompt(id_prompt[1])}],
index=id_prompt[0]
) for id_prompt in batch
]
# Run tasks concurrently. With return_exceptions=True, per-prompt errors are
# returned as Exception objects (handled below) rather than raised, so this
# retry loop only guards against failures of gather itself.
for retry in range(max_retries + 1):
try:
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
break # Exit retry loop on success
except Exception as e:
if retry < max_retries:
await asyncio.sleep(2 ** retry) # Exponential backoff
else:
raise e
# Process results and handle errors
for idx, result in enumerate(batch_results):
original_index = batch[idx][0]
if isinstance(result, Exception):
error_message = f"Error: {str(result)}"
errors[original_index] = error_message
results[original_index] = ""
# Call the error callback if provided
if error_callback:
error_callback(original_index, error_message)
else:
results[original_index] = result[1]
return results
def generate_batch(
self,
prompts: list[str],
**kwargs
):
"""
Synchronous wrapper around async_generate_batch.
"""
# asyncio.run creates a fresh event loop, so this fails in environments that
# already run one (e.g., Jupyter); await async_generate_batch directly there.
return asyncio.run(self.async_generate_batch(prompts, **kwargs))
# =============================================
def generate_more(
self,
initial_prompt: str = None, # Initial prompt
more_prompt: str | list = "Make it MORE",
n_more: int = 2,
sleep_time=5,
):
"""
Iteratively refine the response by prompting the model to make it MORE.
Args:
initial_prompt (str): The initial prompt to start the conversation.
more_prompt (str | list): The prompt used after the initial turn; a list must contain exactly n_more entries.
n_more (int): Number of refinement iterations.
sleep_time (int): Seconds to wait between API calls.
Returns:
list[dict]: The full conversation history after the initial turn and n_more refinements.
"""
if initial_prompt is None:
raise ValueError("initial_prompt must be provided")
if isinstance(more_prompt, list):
assert len(more_prompt) == n_more
else:
more_prompt = [more_prompt] * n_more
prompt_dicts = [{"role": "user", "content": initial_prompt}]
for k in range(n_more + 1):
if k > 0:
# "MORE" utterances after the initial turn
prompt_dicts.append({
"role": "user",
"content": more_prompt[k - 1]
})
response = self.generate(prompt_dicts)
# Append the assistant's answer to the conversation
prompt_dicts.append({
"role": "assistant",
"content": response
})
time.sleep(sleep_time)
return prompt_dicts
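# Example of the iterative "MORE" loop (a sketch; the prompts and n_more value
# are illustrative):
#
#     history = prompter.generate_more(
#         initial_prompt="Write a short poem about the sea.",
#         more_prompt="Make it MORE",
#         n_more=2,
#     )
#     final_answer = history[-1]["content"]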
# =============================================
def make_prompt(
self, prompt: str,
):
"""
Applies the loaded prompt template to the text, or returns it unchanged if no template is set.
"""
if self.prompt_template is not None:
return self.prompt_template.format(text=prompt)
return prompt
def load_prompt_template(
self,
yaml_template: str = None,
):
"""
Builds the prompt template from a YAML string containing "prefix", "core_prompt"
and "suffix" keys, or clears the template when None is given.
"""
if yaml_template is None:
self.prompt_template = None
else:
# Parse the YAML string and assemble the template
parsed = yaml.safe_load(yaml_template)
self.prompt_template = "\n".join([
parsed["prefix"],
parsed["core_prompt"],
parsed["suffix"],
])
return
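# =============================================
# Minimal end-to-end sketch (illustrative only; runs only when the module is
# executed directly). The environment variable, model identifier, and YAML
# template below are assumptions for the example, not values required by this module.
if __name__ == "__main__":
    prompter = Prompter(base_url=HF_URL)
    prompter._set_token(os.environ.get("HF_API_TOKEN", ""))  # hypothetical env var
    prompter._set_model("some-org/some-model")  # hypothetical model identifier
    prompter._set_generation_args({"max_tokens": 256})  # assumed to be accepted by the backend
    # YAML string with the three keys load_prompt_template expects
    template_yaml = (
        "prefix: 'You are a helpful assistant.'\n"
        "core_prompt: 'Summarize the following text: {text}'\n"
        "suffix: 'Answer in one sentence.'\n"
    )
    prompter.load_prompt_template(template_yaml)
    answers = prompter.generate_batch(
        ["First document ...", "Second document ..."],
        batch_size=2,
    )
    print(answers)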