################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from datasets import load_dataset
from termcolor import colored
import random
import numpy as np

# RULER
from .metrics import needle_score, string_match_part, multi_number, multi_words

# NIAH
from data.utils import generate_random_number, read_context_files, create_contexts, NIAH_TEMPLATE, RANDOM_NEEDLE_CITIES

# Scoring function used to grade model output for each task family.
METRICS_FN = {
    'niah': needle_score,
    'multi': multi_number,
    'vt': multi_words,
    'cwe': multi_words,
    'fwe': multi_words,
    'qa': string_match_part,
}

# Default generation lengths (in tokens) per task family.
GEN_LEN = {
    'niah': 64,
    'vt': 30,
    'cwe': 120,
    'fwe': 50,
    'qa': 32,
}

# Root directories of the preprocessed benchmark data.
DATADIR = {
    'ruler': 'data/ruler/data',
    'niah': 'data/niah/data',
}

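# Example resolved RULER path (hypothetical task name): with a Llama-3
# tokenizer, datalen=32 * 1024, and task niah_single_1, validation data is
# loaded from
#     data/ruler/data/llama-3/32768/niah_single_1/validation.jsonl
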
class Dataset:
    """Evaluation dataset wrapper for RULER tasks and the synthetic NIAH benchmark.

    Tokenizes every prompt up front and exposes (prompt, ground-truth) pairs;
    for NIAH it also records per-sample context length and needle depth.
    """

    def __init__(self, dataset_name, tokenizer, datalen, num_samples, rank=0, world_size=1):
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.datalen = datalen
        self.num_samples = num_samples
        self.rank = rank
        self.world_size = world_size
        self.is_sharded = False

        if dataset_name == 'niah':
            self.tokenized_prompts, self.gt, self.ctx_len, self.depth_pct = self.get_dataset()
        else:
            self.tokenized_prompts, self.gt = self.get_dataset()

        self.num_samples = len(self.tokenized_prompts)
        self.gen_len = self.get_gen_len()
        self.metric = self.get_metric()

    def __str__(self) -> str:
        return f"Dataset: {self.dataset_name}, Num Samples: {self.num_samples}, Gen Len: {self.gen_len}, DataLen: {self.datalen}"

    def __repr__(self) -> str:
        return self.__str__()

    def __len__(self) -> int:
        return self.num_samples

    def shard(self, rank, world_size):
        if world_size > 1:
            shard_size = self.num_samples // world_size
            start = rank * shard_size
            # the last rank absorbs the remainder so no sample is dropped
            end = start + shard_size if rank != world_size - 1 else self.num_samples
            shard_tokenized_prompts, shard_gt = self.tokenized_prompts[start:end], self.gt[start:end]
            self.tokenized_prompts = shard_tokenized_prompts
            self.gt = shard_gt
            self.num_samples = len(shard_tokenized_prompts)

        self.is_sharded = True

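    # Worked example (hypothetical sizes): with num_samples=10 and world_size=4,
    # shard_size = 10 // 4 = 2, so ranks 0-2 keep slices [0:2], [2:4], [4:6]
    # and rank 3 keeps the remaining [6:10]; every sample is evaluated exactly once.
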
    def get_gen_len(self):
        # exact match first: the standalone NIAH benchmark only needs a short
        # generation, while RULER's niah_* variants may require up to 128 tokens
        if self.dataset_name == 'niah':
            return 10
        elif 'niah' in self.dataset_name:
            return 128
        elif 'vt' in self.dataset_name:
            return 30
        elif 'cwe' in self.dataset_name:
            return 120
        elif 'fwe' in self.dataset_name:
            return 50
        elif 'qa' in self.dataset_name:
            return 32
        else:
            raise ValueError(f"Gen len not found for dataset {self.dataset_name}")

    def __getitem__(self, idx):
        if 'persona' in self.dataset_name:
            # 'persona' datasets also carry a per-sample query; note that
            # self.queries is not populated by this class itself
            return self.tokenized_prompts[idx], self.queries[idx], self.gt[idx]
        return self.tokenized_prompts[idx], self.gt[idx]

    def get_metric(self):
        # check 'multiquery'/'multivalue' before the generic 'niah' substring,
        # since RULER task names like niah_multiquery contain both
        if 'multiquery' in self.dataset_name or 'multivalue' in self.dataset_name:
            return METRICS_FN['multi']
        elif 'niah' in self.dataset_name:
            return METRICS_FN['niah']
        elif 'vt' in self.dataset_name:
            return METRICS_FN['vt']
        elif 'cwe' in self.dataset_name:
            return METRICS_FN['cwe']
        elif 'fwe' in self.dataset_name:
            return METRICS_FN['fwe']
        elif 'qa' in self.dataset_name:
            return METRICS_FN['qa']
        else:
            raise ValueError(f"Metric not found for dataset {self.dataset_name}")

    def get_dataset(self):
        if 'ruler' in self.dataset_name:  # dataset_name is of the form "ruler/<task>"
            task = self.dataset_name.split('/')[-1]
            assert self.datalen in [8 * 1024, 16 * 1024, 32 * 1024, 64 * 1024, 128 * 1024, 256 * 1024], "Only support datalen of 8k, 16k, 32k, 64k, 128k, 256k"

            if 'llama-3' in self.tokenizer.name_or_path.lower():
                model_dir = 'llama-3'
            elif 'yi' in self.tokenizer.name_or_path.lower():
                model_dir = 'yi'
            elif 'lwm' in self.tokenizer.name_or_path.lower():
                model_dir = 'lwm'
            elif 'glm' in self.tokenizer.name_or_path.lower():
                model_dir = 'glm'
            elif 'qwen' in self.tokenizer.name_or_path.lower():
                model_dir = 'qwen'
            elif 'phi' in self.tokenizer.name_or_path.lower():
                model_dir = 'phi'
            else:
                raise ValueError(f"Model not found: {self.tokenizer.name_or_path}")

            dataset = load_dataset("json", data_files=f'{DATADIR["ruler"]}/{model_dir}/{self.datalen}/{task}/validation.jsonl', split='train')
            if self.num_samples > 0:
                self.num_samples = min(self.num_samples, len(dataset))
            else:
                # a non-positive num_samples means "use every sample"
                self.num_samples = len(dataset)
            tokenized_prompts = []
            gt = []

            for i in range(self.num_samples):
                input_text = dataset[i]['input']
                input_ids = self.tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False)
                tokenized_prompts.append(input_ids)
                gt.append(dataset[i]['outputs'])

            return tokenized_prompts, gt

        elif self.dataset_name == 'niah':
            print(colored(f"[Warning] the NIAH dataset ignores num_samples; the sample count is determined by world_size, which is set to {self.world_size}", 'red'))

            haystack_file = f'{DATADIR["niah"]}/pg19_mini.jsonl'
            context_lengths_min = 16 * 1024
            context_lengths_max = self.datalen
            n_context_length_intervals = 15
            n_document_depth_intervals = 10  # position of the needle in the haystack
            n_rounds = 1  # max(1, 4 // self.world_size)  # 8 rounds in total, assuming 8 GPUs
            needle = "\nThe special magic {city} number is: {rnd_number}\n"
            retrieval_question = "What is the special magic {} number?"
            rnd_number_digits = 7

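            # Example (hypothetical values): with city="Bangkok" and
            # rnd_number="4719532", the needle inserted into the haystack reads
            # "The special magic Bangkok number is: 4719532" and the retrieval
            # question becomes "What is the special magic Bangkok number?".
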
            context_lengths = np.round(
                np.linspace(
                    context_lengths_min,
                    context_lengths_max,
                    num=n_context_length_intervals,
                    endpoint=True,
                )
            ).astype(int)

            document_depth_percents = np.round(  # we use a linear scale here
                np.linspace(
                    0,
                    100,
                    num=n_document_depth_intervals,
                    endpoint=True,
                )
            ).astype(int)

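            # Worked example (assuming datalen = 128 * 1024): context_lengths is
            # 15 evenly spaced points from 16384 to 131072 (a step of 8192), and
            # document_depth_percents is [0, 11, 22, 33, 44, 56, 67, 78, 89, 100].
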
            self.is_sharded = True  # the NIAH data is already sharded while the dataset is built

            full_contexts = read_context_files(n=n_rounds, context_lengths=context_lengths, haystack_file=haystack_file, tokenizer=self.tokenizer)
            full_tokens = [
                self.tokenizer.encode(full_context, add_special_tokens=False) for full_context in full_contexts
            ]

            tokenized_prompts = []
            gt = []
            ctx_len = []
            depth_pct = []

            for context_length in context_lengths:
                trim_contexts = [
                    self.tokenizer.decode(full_token[:context_length], skip_special_tokens=True)
                    for full_token in full_tokens
                ]
                contexts = []
                for depth_percent in document_depth_percents:
                    for i in range(n_rounds):
                        random_city = random.choice(RANDOM_NEEDLE_CITIES)
                        insert_needle = True
                        needle_rnd_number = str(generate_random_number(rnd_number_digits))
                        context = create_contexts(
                            needle_rnd_number=needle_rnd_number,
                            insert_needle=insert_needle,
                            random_city=random_city,
                            trim_context=trim_contexts[i],
                            context_length=context_length,
                            depth_percent=depth_percent,
                            needle=needle,
                            retrieval_question=retrieval_question,
                            tokenizer=self.tokenizer,
                            final_context_length_buffer=32,
                        )
                        contexts.append(context)

                for context in contexts:
                    prompt = NIAH_TEMPLATE.format(
                        context=context["context"], question=context["question"]
                    )
                    input_tensor = self.tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
                    tokenized_prompts.append(input_tensor.input_ids)
                    gt.append(context["needle_rnd_number"])
                    ctx_len.append(context["context_length"])
                    depth_pct.append(context["depth_percent"])

            return tokenized_prompts, gt, ctx_len, depth_pct

        else:
            raise ValueError(f"Dataset {self.dataset_name} not found; please choose from ruler, persona, infini_bench, needle, niah, long_bench")
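
# Minimal usage sketch (hypothetical model name and task; the real evaluation
# harness wires this up elsewhere):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
#     dataset = Dataset("ruler/niah_single_1", tokenizer, datalen=32 * 1024, num_samples=-1)
#     dataset.shard(rank=0, world_size=1)
#     input_ids, answers = dataset[0]
#     # generate up to dataset.gen_len tokens, then score with dataset.metric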