
Commit 30e5dae

Commit message: [refactor] 3.5
1 parent 7e749fc

17 files changed: +428 -913 lines

.gitignore (+2)

@@ -6,6 +6,8 @@ wandb/
 
 # Old
 old/
+temp/
+profiler/
 
 # Logs
 logs/

ochat/config/__init__.py (+43)

@@ -0,0 +1,43 @@
+from functools import partial
+
+import torch
+import transformers
+
+from ochat.config.model_config import ModelConfig
+from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
+import ochat.models
+
+
+_V3_2_PREFIXES = {
+    # ShareGPT & OpenAI mapping
+
+    "human": "User:",
+    "user": "User:",
+    "gpt": "Assistant:",
+    "assistant": "Assistant:"
+}
+
+
+def _v3_2_role_prefix(from_role, condition):
+    return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()
+
+
+MODEL_CONFIG_MAP = {
+    # OpenChat V3.2
+    "openchat_v3.2": ModelConfig(
+        # Model
+        model_max_context=4096,
+        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
+                                       use_fast=False,
+                                       legacy=False),
+        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
+                                          low_cpu_mem_usage=True,
+                                          torch_dtype=torch.bfloat16),
+
+        # Conversation Template
+        conversation_template=partial(ConversationTemplate,
+                                      role_prefix=_v3_2_role_prefix,
+                                      eot="<|end_of_turn|>",
+                                      inference_condition="GPT4")
+    )
+}
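
The refactored config wires everything together with functools.partial, so each MODEL_CONFIG_MAP entry is fully declarative. A minimal consumption sketch (not part of this commit; the checkpoint path is a placeholder):

from ochat.config import MODEL_CONFIG_MAP

config = MODEL_CONFIG_MAP["openchat_v3.2"]

# Instantiate the tokenizer, then bind it into the pre-configured template.
# "path/to/checkpoint" is a placeholder, not a path from this commit.
tokenizer = config.model_tokenizer_create("path/to/checkpoint")
template = config.conversation_template(tokenizer=tokenizer)

# The role prefix prepends the condition string to the base prefix:
# _v3_2_role_prefix("user", "GPT4") == "GPT4 User:"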

ochat/config/conversation_template.py (+116)

@@ -0,0 +1,116 @@
+from typing import Optional, Callable, Iterable, List, Dict
+import re
+
+from pydantic import BaseModel, Field
+
+
+class Message(BaseModel):
+    role: str = Field(..., alias="from")
+    value: str
+
+    weight: Optional[float] = None
+
+
+class Conversation(BaseModel):
+    items: List[Message]
+
+    condition: Optional[str] = None
+    system: str = ""
+
+
+class ConversationTemplate(BaseModel):
+    tokenizer: Callable
+
+    # Prompt
+    role_prefix: Callable
+    eot: str
+
+    inference_condition: Optional[str] = None
+
+    # Private
+    bos_tokens_: List[int]
+    eot_tokens_: List[int]
+
+    def __init__(self, **data):
+        tokenizer = data["tokenizer"]
+        eot = data["eot"]
+        bos_tokens_ = tokenizer("").input_ids
+        eot_tokens_ = tokenizer(eot, add_special_tokens=False).input_ids
+
+        super().__init__(**data, bos_tokens_=bos_tokens_, eot_tokens_=eot_tokens_)
+
+    def safe_tokenize(self, strings: Iterable[str]) -> List[List[int]]:
+        return self.tokenizer(strings, split_special_tokens=True, return_attention_mask=False, add_special_tokens=False).input_ids
+
+    def tokenize_conversations(self, conversations: Iterable[Conversation], inference: bool = False):
+        # Pre-tokenize all conversations
+        default_condition = self.inference_condition if inference else None
+
+        sys_mappings = set()
+        role_mappings = set()
+        all_text = []
+        for conv in conversations:
+            sys_mappings.add(conv.system)
+            for msg in conv.items:
+                role_mappings.add((msg.role, conv.condition or default_condition))
+                all_text.append(msg.value)
+
+        sys_mappings = list(sys_mappings)
+        role_mappings = list(role_mappings)
+
+        # Tokenize
+        sys_mappings = dict(zip(sys_mappings, self.safe_tokenize(sys_mappings)))
+        role_mappings = dict(zip(role_mappings, self.safe_tokenize([self.role_prefix(*args) for args in role_mappings])))
+        all_text = self.safe_tokenize(all_text)
+
+        # Convert
+        result_tokens = []
+        result_weights = []
+        all_text_idx = 0
+        for conv in conversations:
+            tokens = []
+            weights = []
+
+            # bos tokens
+            tokens.extend(self.bos_tokens_)
+            weights.extend([0.] * len(self.bos_tokens_))
+
+            # System
+            if conv.system:
+                system = sys_mappings[conv.system]
+                tokens.extend(system)
+                weights.extend([0.] * len(system))
+
+                tokens.extend(self.eot_tokens_)
+                weights.extend([0.] * len(self.eot_tokens_))
+
+            # Messages
+            last_idx = len(conv.items) - 1
+            for idx, msg in enumerate(conv.items):
+                # Prefix
+                role = role_mappings[(msg.role, conv.condition or default_condition)]
+                tokens.extend(role)
+                weights.extend([0.] * len(role))
+
+                # Message
+                text = all_text[all_text_idx]
+                all_text_idx += 1
+
+                if not inference:
+                    assert msg.weight is not None
+
+                tokens.extend(text)
+                weights.extend([msg.weight] * len(text))
+
+                if not (inference and idx == last_idx):  # Do not add EOT on last turn during inference
+                    tokens.extend(self.eot_tokens_)
+                    weights.extend([msg.weight] * len(self.eot_tokens_))
+
+            # Append result
+            result_tokens.append(tokens)
+            result_weights.append(weights)
+
+        # Sanity check
+        assert all_text_idx == len(all_text)
+
+        return result_tokens, result_weights
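
To make the token/weight bookkeeping concrete, here is a hedged walk-through sketch (not from this commit; assumes a template built as in ochat/config/__init__.py above):

# Sketch only: `template` built as shown in ochat/config/__init__.py.
conv = Conversation(
    condition="GPT4",  # role_prefix needs a condition during training
    items=[
        Message(**{"from": "user",      "value": "Hello",     "weight": 0.0}),
        Message(**{"from": "assistant", "value": "Hi there!", "weight": 1.0}),
    ],
)

tokens, weights = template.tokenize_conversations([conv])

# tokens[0]:  <bos> "GPT4 User:" "Hello" <|end_of_turn|>
#             "GPT4 Assistant:" "Hi there!" <|end_of_turn|>
# weights[0]: 0.0 for bos, role prefixes, and the user turn;
#             1.0 for the assistant text and its trailing EOT token.

Note that tokenize_conversations iterates conversations twice (once to collect strings, once to assemble results), so it expects a list or other re-iterable, not a one-shot generator.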

ochat/config/model_config.py (+8 -214)

@@ -1,219 +1,13 @@
-from typing import Optional, Callable, Union
-from dataclasses import dataclass
-from functools import partial
+from typing import Callable
 
-import torch
-import transformers
-import ochat.models
+from pydantic import BaseModel
 
 
-@dataclass
-class ModelConfig:
-    name: str
-
-    # Prompt
-    role_prefix: Union[dict, Callable]
-    ai_role: str
-    eot_token: str
-    bos_token: Optional[str] = None
-
-    condition_fn: Optional[Callable] = None
-
+class ModelConfig(BaseModel):
     # Model
-    model_max_context: Optional[int] = None
-    model_tokenizer_create: Optional[Callable] = None
-    model_create_for_training: Optional[Callable] = None
-
-    # Get template
-    def generate_conversation_template(self, tokenize_fn, tokenize_special_fn, system_prompt, message_list, message_props=None):
-        tokens = []
-        masks = []
-        weights = []
-
-        # begin of sentence (bos)
-        if self.bos_token:
-            t = tokenize_special_fn(self.bos_token)
-
-            tokens.extend([t])
-            masks.extend([False])
-            weights.extend([0.])
-
-        # Condition
-        if self.condition_fn is not None:
-            t = tokenize_fn(self.condition_fn(message_props)) + [tokenize_special_fn(self.eot_token)]
-
-            tokens.extend(t)
-            masks.extend([False] * len(t))
-            weights.extend([0.] * len(t))
-
-        # System
-        if system_prompt:
-            t = tokenize_fn(system_prompt) + [tokenize_special_fn(self.eot_token)]
-
-            tokens.extend(t)
-            masks.extend([False] * len(t))
-            weights.extend([0.] * len(t))
-
-        # Messages
-        for idx, message in enumerate(message_list):
-            # Prefix
-            if callable(self.role_prefix):
-                role_prefix = self.role_prefix(message["from"], message_props)
-            else:
-                role_prefix = self.role_prefix[message["from"]]
-
-            t = tokenize_fn(role_prefix)
-            tokens.extend(t)
-            masks.extend([False] * len(t))
-            weights.extend([0.] * len(t))
-
-            # Message
-            if "value" in message:
-                t = tokenize_fn(message["value"]) + [tokenize_special_fn(self.eot_token)]
-
-                # determine weights
-                use_loss = (message["from"] == self.ai_role) and bool(message.get("use_loss", True))
-                w = 1.0 if use_loss else 0.0
-
-                if message_props is not None and ("weight" in message_props):
-                    w *= message_props["weight"]
-
-                tokens.extend(t)
-                masks.extend([use_loss] * len(t))
-                weights.extend([w] * len(t))
-            else:
-                assert idx == len(message_list) - 1, "Empty message for completion must be on the last."
-
-        return tokens, masks, weights
-
-
-def _v2_conditional_prefix(from_role, props):
-    human_prefix = "User:"
-    gpt4_prefix = "Assistant GPT4:"
-    other_prefix = "Assistant GPT3:"
-
-    if from_role == "human":
-        return human_prefix
-
-    if from_role == "gpt":
-        if props is None:
-            return gpt4_prefix  # inference using gpt-4 prefix
-
-        return gpt4_prefix if props["is_gpt4"] else other_prefix
-
-    raise NotImplementedError(f"Unknown role {from_role}")
-
-
-def _v3_2_conditional_prefix(from_role, props):
-    gpt3_prefixes = {
-        "human": "GPT3 User:",
-        "gpt": "GPT3 Assistant:"
-    }
-    gpt4_prefixes = {
-        "human": "GPT4 User:",
-        "gpt": "GPT4 Assistant:"
-    }
-    prefixes = gpt4_prefixes if props is None or props["is_gpt4"] else gpt3_prefixes
-
-    return prefixes[from_role]
-
-
-def _v3_condition(props):
-    gpt4_condition = "Assistant is GPT4"
-    gpt3_condition = "Assistant is GPT3"
-
-    if props is None:
-        return gpt4_condition
-
-    return gpt4_condition if props["is_gpt4"] else gpt3_condition
-
-
-MODEL_CONFIG_MAP = {
-    ################# Llama 2 based models
-    # OpenChat V3.2
-    "openchat_v3.2": ModelConfig(
-        name="OpenChat V3.2 Llama 2",
-
-        # Prompt
-        role_prefix=_v3_2_conditional_prefix,
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
-
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    ),
-
-    "openchat_v3.1_llama2": ModelConfig(
-        name="OpenChat V3.1 Llama 2",
-
-        # Prompt
-        role_prefix={
-            "human": "User:",
-            "gpt": "Assistant:"
-        },
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
-
-        condition_fn=_v3_condition,
-
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    ),
-
-    # OpenChat V2
-    "openchat_v2_llama2": ModelConfig(
-        name="OpenChat V2 Llama 2",
-
-        # Prompt
-        role_prefix=_v2_conditional_prefix,
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
-
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    ),
-
-    # OpenChat
-    "openchat_llama2": ModelConfig(
-        name="OpenChat V1 Llama 2",
-
-        # Prompt
-        role_prefix={
-            "human": "User:",
-            "gpt": "Assistant:"
-        },
-        ai_role="gpt",
-        eot_token="<|end_of_turn|>",
-        bos_token="<s>",
+    model_max_context: int
+    model_tokenizer_create: Callable
+    model_create_for_training: Callable
 
-        # Tokenize
-        model_max_context=4096,
-        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
-                                       use_fast=False,
-                                       legacy=True),
-        model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
-                                          low_cpu_mem_usage=True,
-                                          torch_dtype=torch.bfloat16),
-    )
-}
+    # conversation template
+    conversation_template: Callable