# example_cfg.py
# Samples text from a weighted mix of two prompts' next-token log-probabilities.
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import torch
import torch.nn.functional as F
import os, glob
import cuda_ext
# Directory containing the model weights (*.safetensors), config.json and tokenizer.model

model_directory = "/mnt/str/models/_test_models/TheBloke_Llama-2-13B-chat-GPTQ/"

# Locate files we need within that directory

tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)              # list of every matching weights file

# glob.glob() returns an empty list when nothing matches. Stop here with a clear
# message instead of handing an empty list of weight files to the model loader.

if not model_path:
    raise FileNotFoundError(f"No .safetensors weights file found matching: {st_pattern}")

# Create config, model, tokenizer and generator

config = ExLlamaConfig(model_config_path)               # create config from config.json
config.model_path = model_path                          # supply path(s) to model weights file(s)

model = ExLlama(config)                                 # create ExLlama instance and load the weights
tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file

# batch_size = 2: two sequences are processed side by side, matching the two
# prompts this script mixes.

cache = ExLlamaCache(model, batch_size = 2)             # create cache for inference
generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator

# Configure generator (sampling settings read by the generator when sampling)

generator.settings.token_repetition_penalty_max = 1.15
generator.settings.temperature = 0.95
generator.settings.top_k = 40
generator.settings.top_p = 0.75
# generator.settings.typical = 0.95
# Prompt templates to mix. Each one carries a literal "{prompt}" placeholder
# that is filled in by plain string replacement below.

# Template 1: system block describing a helpful, safe assistant.
f1 = """[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{prompt}[/INST]"""

# Template 2: rude persona.
# NOTE(review): the persona text sits after the closing <</SYS>> tag, leaving
# the system block itself empty -- confirm this placement is intended.
f2 = """[INST] <<SYS>>
<</SYS>>
You are a rude and obnoxious assistant. You hate everything and everyone.
{prompt}[/INST]"""

# The same question is substituted into both templates, in template order.
question = "Tell me about Homer Simpson"

prompts = [template.replace("{prompt}", question) for template in (f1, f2)]
def generate_cfg(prompts, alpha, max_new_tokens):
    """Sample one continuation from a blend of two prompts' predictions.

    The prompts are encoded together as one batch. At every step the
    log-softmax of the batch logits is taken and rows 0 and 1 are combined as
    ``(1 - alpha) * row0 + alpha * row1``. One token is sampled from that
    blend and the same token is handed back to the generator for both rows.

    Relies on the module-level ``tokenizer``, ``model``, ``cache`` and
    ``generator`` objects.

    Args:
        prompts: Prompt strings to encode together; rows 0 and 1 of the
            resulting batch are the ones that get mixed.
        alpha: Weight applied to row 1; row 0 is weighted by ``1 - alpha``.
        max_new_tokens: Upper bound on the number of sampling steps.

    Returns:
        The decoded text of row 0 of ``generator.sequence``.
    """
    input_ids, mask = tokenizer.encode(prompts, return_mask = True)
    generator.gen_begin(input_ids, mask = mask)

    # Sampling loop

    for _step in range(max_new_tokens):

        # Only the most recent token of every row is fed forward.
        last_tokens = generator.sequence[:, -1:]
        step_logits = model.forward(last_tokens, cache, input_mask = mask)

        # Called for its effect on step_logits; the return value is not used.
        generator.apply_rep_penalty(step_logits)

        log_probs = F.log_softmax(step_logits, dim = -1)
        row0, row1 = log_probs[0], log_probs[1]
        blended = (1 - alpha) * row0 + alpha * row1

        next_token, _ = generator.sample_current(blended)

        # Stop on end-of-sequence, before the token is accepted.
        if next_token.item() == tokenizer.eos_token_id:
            break

        # Both rows continue with the very same sampled token.
        generator.gen_accept_token(next_token.repeat(2, 1))

    return tokenizer.decode(generator.sequence[0])
# Sweep alpha over ten values, from -0.4 upward in steps of 0.2, and print the
# text generated for each one.

rule = "--------------------------------------"

for step in range(10):

    alpha = step / 5.0 - 0.4

    # One blank line, then the alpha value framed by two rules.
    print("", rule, f"alpha = {alpha:.1f}", rule, sep = "\n")

    text = generate_cfg(prompts, alpha, 200)

    # Skip as many leading characters as the first prompt is long, then trim
    # surrounding whitespace before printing.
    print(text[len(prompts[0]):].strip())