-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathchatgpt.py
84 lines (73 loc) · 2.46 KB
/
chatgpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json.decoder
import openai
from common.enumeration.sql import LLM
import time
def init_chatgpt(OPENAI_API_KEY, OPENAI_GROUP_ID, model):
    """Install API credentials for the backend that serves ``model``.

    Args:
        OPENAI_API_KEY: API key; used for DashScope as well as OpenAI.
        OPENAI_GROUP_ID: OpenAI organization id; ignored for DashScope.
        model: model identifier, compared against ``LLM.TONG_YI_QIAN_WEN``.
    """
    if model == LLM.TONG_YI_QIAN_WEN:
        # Imported only on this path, so dashscope need not be installed
        # when an OpenAI model is used.
        import dashscope
        dashscope.api_key = OPENAI_API_KEY
        return

    # Every other model goes through the module-level openai client config.
    openai.api_key = OPENAI_API_KEY
    openai.organization = OPENAI_GROUP_ID
def ask_completion(model, batch, temperature):
    """Query the (legacy) OpenAI Completions endpoint.

    Args:
        model: completion model name, passed through to the API.
        batch: sent as the ``prompt`` argument (``ask_llm`` passes a list of
            prompts here).
        temperature: sampling temperature.

    Returns:
        dict whose ``response`` entry is the list of ``text`` fields of the
        returned choices, merged with the key/value pairs of the API's
        ``usage`` object (presumably token counts -- confirm against the API).
    """
    request = {
        "model": model,
        "prompt": batch,
        "temperature": temperature,
        "max_tokens": 200,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        # Generation halts at the first ';'.
        "stop": [";"],
    }
    completion = openai.Completion.create(**request)

    texts = [choice["text"] for choice in completion["choices"]]
    # Keyword-merge (not dict.update) so a clashing "response" key in usage
    # still raises TypeError rather than silently overwriting.
    return dict(response=texts, **completion["usage"])
def ask_chat(model, messages: list, temperature, n):
    """Query the OpenAI ChatCompletion endpoint.

    Args:
        model: chat model name, passed through to the API.
        messages: chat message list, e.g. ``[{"role": "user", "content": ...}]``.
        temperature: sampling temperature.
        n: number of choices to request.

    Returns:
        dict whose ``response`` entry is the message content of the single
        choice when ``n == 1``, otherwise the list of all choices' contents;
        merged with the key/value pairs of the API's ``usage`` object.
    """
    completion = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=200,
        n=n,
    )

    contents = [choice["message"]["content"] for choice in completion["choices"]]
    # A single sample is unwrapped to a plain string.
    answer = contents[0] if n == 1 else contents
    return dict(response=answer, **completion["usage"])
def ask_llm(model: str, batch: list, temperature: float, n: int,
            max_retries: "int | None" = None):
    """Send ``batch`` to ``model``, retrying failed requests.

    Models in ``LLM.TASK_COMPLETIONS`` are queried via ``ask_completion``;
    models in ``LLM.TASK_CHAT`` via ``ask_chat``. When a request raises an
    ``Exception``, a message is printed and the request is retried after a
    1-second pause.

    Args:
        model: model name; must be in ``LLM.TASK_COMPLETIONS`` or
            ``LLM.TASK_CHAT``.
        batch: list of prompts. Chat models accept exactly one prompt.
        temperature: sampling temperature.
        n: samples per prompt; must be 1 for completion models.
        max_retries: how many times to retry after the first failed attempt.
            ``None`` (the default) retries indefinitely.

    Returns:
        The dict produced by ``ask_completion`` / ``ask_chat``. In chat mode
        its ``response`` entry is wrapped in a one-element list, so
        ``response[0]`` is the result for the single prompt (a string when
        ``n == 1``, a list of ``n`` strings otherwise).

    Raises:
        ValueError: ``model`` is in neither task set, ``n != 1`` in
            completion mode, or ``len(batch) != 1`` in chat mode. These are
            checked before the retry loop: inside it they would be caught by
            the catch-all handler and retried forever.
        Exception: the last request error, once ``max_retries`` is exhausted.
    """
    is_completion = model in LLM.TASK_COMPLETIONS
    if is_completion:
        # TODO: self-consistency in this mode
        if n != 1:
            raise ValueError("n must be 1 in completion mode")
    elif model in LLM.TASK_CHAT:
        # batch size must be 1
        if len(batch) != 1:
            raise ValueError("batch must be 1 in this mode")
    else:
        # Without this, `response` would never be bound (UnboundLocalError).
        raise ValueError(f"unsupported model for ask_llm: {model!r}")

    n_repeat = 0
    while True:
        try:
            if is_completion:
                return ask_completion(model, batch, temperature)
            messages = [{"role": "user", "content": batch[0]}]
            response = ask_chat(model, messages, temperature, n)
            # Index the chat result per prompt, like the completion result.
            response['response'] = [response['response']]
            return response
        except openai.error.RateLimitError as err:
            reason, last_error = "RateLimitError", err
        except json.decoder.JSONDecodeError as err:
            reason, last_error = "JSONDecodeError", err
        except Exception as err:
            reason, last_error = f"exception: {err}", err

        n_repeat += 1
        if max_retries is not None and n_repeat > max_retries:
            raise last_error
        print(f"Repeat for the {n_repeat} times for {reason}", end="\n")
        time.sleep(1)