forked from valhongli/reID-PCB
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
88 lines (70 loc) · 2.77 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import re
from PIL import Image
import torch
import torch.utils.data as data
from utils import log
class Market1501(data.Dataset):
    """Market-1501 person re-identification dataset.

    Every sample is a 4-tuple ``(img, label, camera, img_filename)``:
    ``label`` and ``camera`` are integers parsed from the image file name
    (``label`` may be ``-1``), and ``img_filename`` is the full path of the
    image on disk.

    Args:
        root: Directory containing the ``Market-1501-v15.09.15`` folder.
        data_type: ``'train'`` or ``'test'`` select that split; any other
            value selects the query split.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the integer label.
        download: If true, call :meth:`download` (currently a no-op).
        once: If true, load and transform every image up front and keep the
            results in memory. ``transform`` must then return tensors of a
            common shape, because they are combined with ``torch.stack``.
    """

    base_folder = 'Market-1501-v15.09.15'
    train_folder = 'bounding_box_train'
    test_folder = 'bounding_box_test'
    query_folder = 'query'

    def __init__(self, root, data_type='train',
                 transform=None, target_transform=None,
                 download=False, once=False):
        self.root = root
        self.data_type = data_type
        self.transform = transform
        self.target_transform = target_transform
        self.once = once

        if download:
            self.download()

        if self.data_type == 'train':
            split = self.train_folder
        elif self.data_type == 'test':
            split = self.test_folder
        else:
            # Any other value (e.g. 'query') falls back to the query split.
            split = self.query_folder
        self.folder = os.path.join(self.root, self.base_folder, split)

        # Group 1 is parsed as the label and group 2 as the camera index
        # (see _parse_filename); non-matching files are ignored.
        self.pattern = re.compile(r'^(\-1|\d{4})_c(\d)s\d_\d{6}_\d{2}.*\.jpg$')
        # Sort so sample indices are reproducible: os.listdir order is
        # arbitrary and filesystem-dependent.
        self.file_list = sorted(
            name for name in os.listdir(self.folder) if self.pattern.search(name))

        if self.once:
            self.load_data_at_once()

    def load_data_at_once(self):
        """Load every sample into memory.

        Populates ``self.data`` (all images stacked along a new first
        dimension), ``self.labels``, ``self.cameras`` and
        ``self.img_filenames`` so ``__getitem__`` can serve samples without
        touching the disk.

        Raises:
            RuntimeError: from ``torch.stack`` if there are no files or the
                transformed images are not tensors of one common shape.
        """
        self.data, self.labels, self.cameras, self.img_filenames = [], [], [], []
        total = len(self.file_list)
        for k, file in enumerate(self.file_list, 1):
            img, label, camera, img_filename = self.load_image(file)
            self.data.append(img)
            self.labels.append(label)
            self.cameras.append(camera)
            self.img_filenames.append(img_filename)
            if k % 500 == 0:
                # k is the number of images loaded so far (1-based).
                log('[%s_data_loading] %5d/%5d' % (self.data_type, k, total))
        # stack adds the batch dimension itself, so no image size is
        # hard-coded (the former cat + view assumed 3x768x256).
        self.data = torch.stack(self.data, 0)

    def __getitem__(self, index):
        """Return ``(img, label, camera, img_filename)`` for ``index``."""
        if self.once:
            return (self.data[index], self.labels[index],
                    self.cameras[index], self.img_filenames[index])
        return self.load_image(self.file_list[index])

    def _parse_filename(self, filename):
        """Return ``(label, camera)`` as ints parsed from a matching name."""
        label, camera = self.pattern.match(filename).groups()
        return int(label), int(camera)

    def load_image(self, filename):
        """Load one sample from ``self.folder``.

        Returns:
            ``(img, label, camera, img_filename)``, where ``img`` has been
            passed through ``transform`` and ``label`` through
            ``target_transform`` when those are set.
        """
        label, camera = self._parse_filename(filename)
        img_filename = os.path.join(self.folder, filename)
        # Decode inside ``with`` so the file handle is always closed;
        # convert() returns an independent RGB copy that outlives the block.
        with Image.open(img_filename) as f:
            img = f.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # The target transform applies to the label, not the image.
            label = self.target_transform(label)
        return img, label, camera, img_filename

    def __len__(self):
        return len(self.file_list)

    def download(self):
        """No-op placeholder: the dataset must already exist under ``root``."""
        pass

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        return fmt_str