Commit 21e29b3

add sources
1 parent f587462 commit 21e29b3

57 files changed, +13345 -0 lines changed

data/dataset.py (+245 lines)
@@ -0,0 +1,245 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from datasets import load_dataset
from termcolor import colored
import random
import numpy as np

# RULER
from .metrics import needle_score, string_match_part, multi_number, multi_words

# NIAH
from data.utils import generate_random_number, read_context_files, create_contexts, NIAH_TEMPLATE, RANDOM_NEEDLE_CITIES

METRICS_FN = {
    'niah': needle_score,
    'multi': multi_number,
    'vt': multi_words,
    'cwe': multi_words,
    'fwe': multi_words,
    'qa': string_match_part,
}

GEN_LEN = {
    'niah': 64,
    'vt': 30,
    'cwe': 120,
    'fwe': 50,
    'qa': 32,
}

DATADIR = {
    'ruler': 'data/ruler/data',
    'niah': 'data/niah/data',
}

class Dataset:
    def __init__(self, dataset_name, tokenizer, datalen, num_samples, rank=0, world_size=1):
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.datalen = datalen
        self.num_samples = num_samples
        self.rank = rank
        self.world_size = world_size
        self.is_sharded = False

        if dataset_name == 'niah':
            self.tokenized_prompts, self.gt, self.ctx_len, self.depth_pct = self.get_dataset()
        else:
            self.tokenized_prompts, self.gt = self.get_dataset()

        self.num_samples = len(self.tokenized_prompts)
        self.gen_len = self.get_gen_len()
        self.metric = self.get_metric()

    def __str__(self) -> str:
        return f"Dataset: {self.dataset_name}, Num Samples: {self.num_samples}, Gen Len: {self.gen_len}, DataLen: {self.datalen}"

    def __repr__(self) -> str:
        return f"Dataset: {self.dataset_name}, Num Samples: {self.num_samples}, Gen Len: {self.gen_len}, DataLen: {self.datalen}"

    def __len__(self) -> int:
        return self.num_samples

    def shard(self, rank, world_size):
        if world_size > 1:
            shard_size = self.num_samples // world_size
            start = rank * shard_size
            end = start + shard_size if rank != world_size - 1 else self.num_samples
            shard_tokenized_prompts, shard_gt = self.tokenized_prompts[start:end], self.gt[start:end]
            self.tokenized_prompts = shard_tokenized_prompts
            self.gt = shard_gt
            self.num_samples = len(shard_tokenized_prompts)

        self.is_sharded = True

    def get_gen_len(self):
        if 'niah' == self.dataset_name:
            return 10
        elif 'niah' in self.dataset_name:
            return 128
        elif 'vt' in self.dataset_name:
            return 30
        elif 'cwe' in self.dataset_name:
            return 120
        elif 'fwe' in self.dataset_name:
            return 50
        elif 'qa' in self.dataset_name:
            return 32
        else:
            raise Exception("Gen len not found")

    def __getitem__(self, idx):
        if 'persona' in self.dataset_name:
            return self.tokenized_prompts[idx], self.queries[idx], self.gt[idx]
        return self.tokenized_prompts[idx], self.gt[idx]

    def get_metric(self):
        if 'multiquery' in self.dataset_name or 'multivalue' in self.dataset_name:
            return METRICS_FN['multi']
        elif 'niah' in self.dataset_name:
            return METRICS_FN['niah']
        elif 'vt' in self.dataset_name:
            return METRICS_FN['vt']
        elif 'cwe' in self.dataset_name:
            return METRICS_FN['cwe']
        elif 'fwe' in self.dataset_name:
            return METRICS_FN['fwe']
        elif 'qa' in self.dataset_name:
            return METRICS_FN['qa']
        else:
            raise Exception("Metric not found")

    def get_dataset(self):
        if 'ruler' in self.dataset_name:  # ruler/xxx
            task = self.dataset_name.split('/')[-1]
            assert self.datalen in [8*1024, 16*1024, 32*1024, 64*1024, 128*1024, 256*1024], "Only support datalen of 8k, 16k, 32k, 64k, 128k, or 256k"

            if 'llama-3' in self.tokenizer.name_or_path.lower():
                model_dir = 'llama-3'
            elif 'yi' in self.tokenizer.name_or_path.lower():
                model_dir = 'yi'
            elif 'lwm' in self.tokenizer.name_or_path.lower():
                model_dir = 'lwm'
            elif 'glm' in self.tokenizer.name_or_path.lower():
                model_dir = 'glm'
            elif 'qwen' in self.tokenizer.name_or_path.lower():
                model_dir = 'qwen'
            elif 'phi' in self.tokenizer.name_or_path.lower():
                model_dir = 'phi'
            else:
                raise Exception("Model not found", self.tokenizer.name_or_path)

            dataset = load_dataset("json", data_files=f'{DATADIR["ruler"]}/{model_dir}/{self.datalen}/{task}/validation.jsonl', split='train')
            if self.num_samples > 0:
                self.num_samples = min(self.num_samples, len(dataset))
            else:
                self.num_samples = len(dataset)
            tokenized_prompts = []
            gt = []

            for i in range(self.num_samples):
                input_text = dataset[i]['input']
                input_ids = self.tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False)
                tokenized_prompts.append(input_ids)
                gt.append(dataset[i]['outputs'])

            return tokenized_prompts, gt

        elif self.dataset_name == 'niah':
            print(colored(f"[Warning] NIAH dataset cannot set # samples, it is up to world_size, which is set to {self.world_size}", 'red'))

            haystack_file = f'{DATADIR["niah"]}/pg19_mini.jsonl'
            context_lengths_min = 16*1024
            context_lengths_max = self.datalen
            n_context_length_intervals = 15
            n_document_depth_intervals = 10  # position of the needle in the haystack
            n_rounds = 1  # max(1, 4 // self.world_size) # 8 rounds in total assume we have 8xGPUs
            needle = "\nThe special magic {city} number is: {rnd_number}\n"
            retrieval_question = "What is the special magic {} number?"
            rnd_number_digits = 7

            context_lengths = np.round(
                np.linspace(
                    context_lengths_min,
                    context_lengths_max,
                    num=n_context_length_intervals,
                    endpoint=True,
                )
            ).astype(int)

            document_depth_percents = np.round(  # we use linear scale here
                np.linspace(
                    0,
                    100,
                    num=n_document_depth_intervals,
                    endpoint=True,
                )
            ).astype(int)

            self.is_sharded = True  # we shard the data during init dataset

            full_contexts = read_context_files(n=n_rounds, context_lengths=context_lengths, haystack_file=haystack_file, tokenizer=self.tokenizer)
            full_tokens = [
                self.tokenizer.encode(full_context, add_special_tokens=False) for full_context in full_contexts
            ]

            tokenized_prompts = []
            gt = []
            ctx_len = []
            depth_pct = []

            for context_length in context_lengths:
                trim_contexts = [
                    self.tokenizer.decode(full_token[:context_length], skip_special_tokens=True)
                    for full_token in full_tokens
                ]
                contexts = []
                for depth_percent in document_depth_percents:
                    for i in range(n_rounds):
                        random_city = random.choice(RANDOM_NEEDLE_CITIES)
                        insert_needle = True
                        needle_rnd_number = str(generate_random_number(rnd_number_digits))
                        context = create_contexts(
                            needle_rnd_number=needle_rnd_number,
                            insert_needle=insert_needle,
                            random_city=random_city,
                            trim_context=trim_contexts[i],
                            context_length=context_length,
                            depth_percent=depth_percent,
                            needle=needle,
                            retrieval_question=retrieval_question,
                            tokenizer=self.tokenizer,
                            final_context_length_buffer=32,
                        )
                        contexts.append(context)

                for context in contexts:
                    prompt = NIAH_TEMPLATE.format(
                        context=context["context"], question=context["question"]
                    )
                    input_tensor = self.tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
                    tokenized_prompts.append(input_tensor.input_ids)
                    gt.append(context["needle_rnd_number"])
                    ctx_len.append(context["context_length"])
                    depth_pct.append(context["depth_percent"])

            return tokenized_prompts, gt, ctx_len, depth_pct

        else:
            raise ValueError(f"Dataset {self.dataset_name} not found, please choose in ruler, persona, infini_bench, needle, niah, long_bench")
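For orientation, a minimal usage sketch (not part of the commit): it assumes a Hugging Face AutoTokenizer and that the pre-built RULER validation files exist under data/ruler/data; the model name and task name below are illustrative only.

from transformers import AutoTokenizer
from data.dataset import Dataset

# Hypothetical model and task names; any tokenizer whose name_or_path contains one of the
# supported keys (llama-3, yi, lwm, glm, qwen, phi) would select the matching data directory.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
dataset = Dataset("ruler/niah_single_1", tokenizer, datalen=32*1024, num_samples=10)
print(dataset)

# Optionally split the prompts across data-parallel ranks before evaluation.
dataset.shard(rank=0, world_size=2)
for i in range(len(dataset)):
    input_ids, gt = dataset[i]
    # generate up to dataset.gen_len tokens from input_ids, then score with dataset.metric(pred, gt)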

data/metrics.py (+88 lines)
@@ -0,0 +1,88 @@
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import re
import string

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def postprocess_pred(predict_str: str):

    predict_str = predict_str.strip().replace('<|eot_id|>', '').replace('</s>', '').replace('</s', '').replace('</', '')

    # Remove all non-printable characters
    np_pattern = re.compile(r'[\x00-\x1f]')
    predict_str = np_pattern.sub('\n', predict_str).strip()

    return predict_str

def string_match_part(preds, refs):
    preds = postprocess_pred(preds)
    if isinstance(refs, str):
        refs = [refs]
    score_ref_in_pred = max([1.0 if r.lower() in preds.lower() else 0.0 for r in refs])
    score_pred_in_ref = max([1.0 if preds.lower() in r.lower() else 0.0 for r in refs])
    score = max(score_ref_in_pred, score_pred_in_ref)
    return round(score, 2)

def multi_number(prediction: str, ground_truth: list) -> float:
    assert type(prediction) == str, f"Prediction is not a string, but {prediction}, type: {type(prediction)}"
    assert type(ground_truth) == list, f"Ground truth is not a list, but {ground_truth}, type: {type(ground_truth)}"
    prediction = normalize_answer(prediction)
    prediction_list = re.findall(r'\d+', prediction)
    hits = [item for item in ground_truth if item in prediction_list]
    hit_rate = len(hits) / len(ground_truth)

    return hit_rate

def multi_words(prediction: str, ground_truth: list) -> float:
    prediction = prediction.lower()
    ground_truth = [gt.lower() for gt in ground_truth]
    prediction_list = re.findall(r'\b\w+\b', prediction)
    hits = [item for item in ground_truth if item in prediction_list]
    hit_rate = len(hits) / len(ground_truth)

    return hit_rate

def needle_score(prediction, ground_truth):
    assert type(prediction) == str, f"Prediction is not a string, but {prediction}, type: {type(prediction)}"
    assert type(ground_truth) == str, f"Ground truth is not a string, but {ground_truth}, type: {type(ground_truth)}"
    prediction = normalize_answer(postprocess_pred(prediction))
    ground_truth = normalize_answer(ground_truth)
    min_length = min(len(prediction), len(ground_truth))
    min_length = len(ground_truth)
    score = float((prediction[:min_length] == ground_truth[:min_length]))
    pred_list = prediction.split()
    score = max(float(ground_truth in pred_list), score)
    return score
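
For reference, a few illustrative calls (not part of the commit) showing how the scoring helpers behave; the strings and lists below are made-up examples.

from data.metrics import needle_score, string_match_part, multi_number, multi_words

# Needle retrieval: full credit if the normalized prediction starts with the
# ground-truth number or contains it as a whitespace-separated token.
needle_score("The special magic number is: 9054321", "9054321")    # 1.0

# QA tasks: substring match in either direction between prediction and any reference.
string_match_part("The answer is Paris.", ["Paris"])                # 1.0

# Multi-value/multi-query tasks: fraction of ground-truth numbers found in the prediction.
multi_number("the numbers are 123 and 456", ["123", "456", "789"])  # ~0.67

# VT/CWE/FWE tasks: fraction of ground-truth words found in the prediction.
multi_words("apple banana cherry", ["apple", "grape"])              # 0.5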
