forked from yhlleo/cifar10Dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pickled.py
51 lines (45 loc) · 1.48 KB
/
pickled.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: pickled.py
# Author: Yahui Liu <[email protected]>
import os
import pickle

try:
    import cPickle  # Python 2: C-accelerated pickle implementation
except ImportError:
    # Python 3 has no cPickle module; the C accelerator lives inside pickle.
    cPickle = pickle
# Default number of batch files a dataset is split into.
BIN_COUNTS = 5


def pickled(savepath, data, label, fnames, bin_num=BIN_COUNTS, mode="train"):
    '''
    Split a dataset into `bin_num` consecutive batches and pickle each batch
    to the file `<savepath>/data_batch_<idx>` (idx = 0 .. bin_num-1).

    Every batch is a dict with the keys 'data', 'labels', 'filenames' and
    'batch_label'. Each batch holds len(fnames) // bin_num samples; the last
    batch additionally receives the remainder when the sample count is not
    divisible by bin_num, so no sample is dropped.

    savepath (str): save path, must be an existing directory
    data (array): image data, a nx3072 array (must support 2-D slicing,
        e.g. a numpy array)
    label (list): image label, a list with length n
    fnames (str list): image names, a list with length n
    bin_num (int): save data in several files
    mode (str): {'train', 'test'}; only affects the 'batch_label' text, any
        value other than 'train' is labelled as a testing batch. The file
        names are 'data_batch_<idx>' in both modes.

    Raises AssertionError if savepath is not a directory or if there are
    fewer samples than bins.
    '''
    assert os.path.isdir(savepath), "savepath must be an existing directory"
    total_num = len(fnames)
    # Floor division keeps the slice bounds integral on Python 2 and 3 alike.
    samples_per_bin = total_num // bin_num
    assert samples_per_bin > 0, "need at least one sample per bin"

    if mode == "train":
        label_template = "training batch {} of {}"
    else:
        label_template = "testing batch {} of {}"

    for idx in range(bin_num):
        start = idx * samples_per_bin
        if idx == bin_num - 1:
            # The last batch absorbs the total_num % bin_num leftover samples.
            end = total_num
        else:
            end = start + samples_per_bin
        batch = {
            'data': data[start:end, :],
            'labels': label[start:end],
            'filenames': fnames[start:end],
            'batch_label': label_template.format(idx, bin_num),
        }
        with open(os.path.join(savepath, 'data_batch_' + str(idx)), 'wb') as fi:
            pickle.dump(batch, fi)
def unpickled(filename):
    '''
    Load and return the object stored in a pickle file.

    filename (str): path of an existing pickle file, e.g. a batch file
        written by pickled()

    Returns the unpickled object (a dict for batch files written by
    pickled()).

    Raises AssertionError if filename is not an existing regular file.
    '''
    assert os.path.isfile(filename), "filename must be an existing file"
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(filename, 'rb') as fo:
        return pickle.load(fo)