```diff
@@ -1,219 +1,13 @@
-from typing import Optional, Callable, Union
-from dataclasses import dataclass
-from functools import partial
+from typing import Callable
 
-import torch
-import transformers
-import ochat.models
+from pydantic import BaseModel
 
 
-@dataclass
-class ModelConfig:
-    name: str
-
-    # Prompt
-    role_prefix: Union[dict, Callable]
-    ai_role: str
-    eot_token: str
-    bos_token: Optional[str] = None
-
-    condition_fn: Optional[Callable] = None
-
+class ModelConfig(BaseModel):
     # Model
-    model_max_context: Optional[int] = None
-    model_tokenizer_create: Optional[Callable] = None
-    model_create_for_training: Optional[Callable] = None
-
-    # Get template
-    def generate_conversation_template(self, tokenize_fn, tokenize_special_fn, system_prompt, message_list, message_props=None):
-        tokens = []
-        masks = []
-        weights = []
-
-        # begin of sentence (bos)
-        if self.bos_token:
-            t = tokenize_special_fn(self.bos_token)
-
-            tokens.extend([t])
-            masks.extend([False])
-            weights.extend([0.])
-
-        # Condition
-        if self.condition_fn is not None:
-            t = tokenize_fn(self.condition_fn(message_props)) + [tokenize_special_fn(self.eot_token)]
-
-            tokens.extend(t)
-            masks.extend([False] * len(t))
-            weights.extend([0.] * len(t))
-
-        # System
-        if system_prompt:
-            t = tokenize_fn(system_prompt) + [tokenize_special_fn(self.eot_token)]
-
-            tokens.extend(t)
-            masks.extend([False] * len(t))
-            weights.extend([0.] * len(t))
-
-        # Messages
-        for idx, message in enumerate(message_list):
-            # Prefix
-            if callable(self.role_prefix):
-                role_prefix = self.role_prefix(message["from"], message_props)
-            else:
-                role_prefix = self.role_prefix[message["from"]]
-
-            t = tokenize_fn(role_prefix)
-            tokens.extend(t)
-            masks.extend([False] * len(t))
-            weights.extend([0.] * len(t))
-
-            # Message
-            if "value" in message:
-                t = tokenize_fn(message["value"]) + [tokenize_special_fn(self.eot_token)]
-
-                # determine weights
-                use_loss = (message["from"] == self.ai_role) and bool(message.get("use_loss", True))
-                w = 1.0 if use_loss else 0.0
-
-                if message_props is not None and ("weight" in message_props):
-                    w *= message_props["weight"]
-
-                tokens.extend(t)
-                masks.extend([use_loss] * len(t))
-                weights.extend([w] * len(t))
-            else:
-                assert idx == len(message_list) - 1, "Empty message for completion must be on the last."
-
-        return tokens, masks, weights
-
-
-def _v2_conditional_prefix(from_role, props):
-    human_prefix = "User:"
-    gpt4_prefix = "Assistant GPT4:"
-    other_prefix = "Assistant GPT3:"
-
-    if from_role == "human":
-        return human_prefix
-
-    if from_role == "gpt":
-        if props is None:
-            return gpt4_prefix # inference using gpt-4 prefix
-
-        return gpt4_prefix if props["is_gpt4"] else other_prefix
-
-    raise NotImplementedError(f"Unknown role {from_role}")
-
-
-def _v3_2_conditional_prefix(from_role, props):
-    gpt3_prefixes = {
-        "human": "GPT3 User:",
-        "gpt": "GPT3 Assistant:"
-    }
-    gpt4_prefixes = {
-        "human": "GPT4 User:",
-        "gpt": "GPT4 Assistant:"
-    }
-    prefixes = gpt4_prefixes if props is None or props["is_gpt4"] else gpt3_prefixes
-
-    return prefixes[from_role]
-
-
-def _v3_condition(props):
-    gpt4_condition = "Assistant is GPT4"
-    gpt3_condition = "Assistant is GPT3"
-
-    if props is None:
-        return gpt4_condition
-
-    return gpt4_condition if props["is_gpt4"] else gpt3_condition
-
-
-MODEL_CONFIG_MAP = {
-    ################# Llama 2 based models
-    # OpenChat V3.2
-    "openchat_v3.2": ModelConfig(
-        name="OpenChat V3.2 Llama 2",
-
-        # Prompt
-        role_prefix=_v3_2_conditional_prefix,
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
-
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    ),
-
-    "openchat_v3.1_llama2": ModelConfig(
-        name="OpenChat V3.1 Llama 2",
-
-        # Prompt
-        role_prefix={
-            "human": "User:",
-            "gpt": "Assistant:"
-        },
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
-
-        condition_fn=_v3_condition,
-
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    ),
-
-    # OpenChat V2
-    "openchat_v2_llama2": ModelConfig(
-        name="OpenChat V2 Llama 2",
-
-        # Prompt
-        role_prefix=_v2_conditional_prefix,
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
-
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    ),
-
-    # OpenChat
-    "openchat_llama2": ModelConfig(
-        name="OpenChat V1 Llama 2",
-
-        # Prompt
-        role_prefix={
-            "human": "User:",
-            "gpt": "Assistant:"
-        },
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
+    model_max_context: int
+    model_tokenizer_create: Callable
+    model_create_for_training: Callable
 
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    )
-}
+    # conversation template
+    conversation_template: Callable
```
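For reference, the deleted `generate_conversation_template` built three parallel per-token lists: `tokens`, a boolean `masks` list marking assistant tokens, and float `weights` for loss scaling. A minimal sketch of its output on a toy conversation, assuming a word-level stub tokenizer for illustration (the real code used a Hugging Face tokenizer created via `model_tokenizer_create`):

```python
# Toy conversation in the format the removed method consumed.
message_list = [
    {"from": "human", "value": "Hello"},
    {"from": "gpt",   "value": "Hi"},
]

# Assuming role_prefix={"human": "User:", "gpt": "Assistant:"},
# eot_token="<|end_of_turn|>", bos_token="<s>", and a stub tokenizer that
# splits on whitespace, the removed method returned:
#
#   tokens:  <s>  User:  Hello  <|end_of_turn|>  Assistant:  Hi  <|end_of_turn|>
#   masks:   F    F      F      F                F           T   T
#   weights: 0.   0.     0.     0.               0.          1.  1.
#
# Only tokens written by the ai_role ("gpt"), plus the EOT token that closes
# them, got mask=True and weight=1.0; message_props["weight"] could further
# scale those weights, and role prefixes were never trained on.
```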
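After this change, `ModelConfig` keeps only the model-loading fields plus a single `conversation_template` callable, so the prompt logic removed above presumably moves into whatever builds that callable. A hedged sketch of how a registry entry might look under the new schema: the `make_conversation_template` helper is a hypothetical placeholder (not in this commit), while the factory `partial`s are copied from the removed `MODEL_CONFIG_MAP`:

```python
from functools import partial

import torch
import transformers

import ochat.models
# ModelConfig here is the new pydantic BaseModel defined in this commit.


def make_conversation_template(tokenizer):
    # Hypothetical placeholder: the real template builder is not part of this diff.
    raise NotImplementedError


MODEL_CONFIG_MAP = {
    "openchat_v3.2": ModelConfig(
        # Model
        model_max_context=4096,
        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
                                       use_fast=False,
                                       legacy=True),
        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
                                          low_cpu_mem_usage=True,
                                          torch_dtype=torch.bfloat16),

        # conversation template
        conversation_template=make_conversation_template,
    )
}
```

Pydantic validates each `Callable` field with `callable()`, so plain functions and `functools.partial` objects both pass.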