Commit 1121099: liubofang committed on Jun 5, 2018 (0 parents).
Showing 13 changed files with 927 additions and 0 deletions.
New file (.gitignore):
@@ -0,0 +1,4 @@
*.swp
*.pyc
__pycache__
coco
Empty file.
New file (COCO dataset loader):
@@ -0,0 +1,77 @@
import os
import numpy as np

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

from PIL import Image
from skimage.transform import resize


class COCODataset(Dataset):
    def __init__(self, list_path, img_size=416):
        with open(list_path, 'r') as file:
            self.img_files = file.readlines()
        self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
                            for path in self.img_files]
        self.img_shape = (img_size, img_size)
        self.max_objects = 50

    def __getitem__(self, index):
        img_path = self.img_files[index % len(self.img_files)].rstrip()
        img = np.array(Image.open(img_path))

        # Black and white images
        if len(img.shape) == 2:
            img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

        h, w, _ = img.shape
        dim_diff = np.abs(h - w)
        # Upper (left) and lower (right) padding
        pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
        # Determine padding
        pad = ((pad1, pad2), (0, 0), (0, 0)) if h <= w else ((0, 0), (pad1, pad2), (0, 0))
        # Add padding
        input_img = np.pad(img, pad, 'constant', constant_values=128) / 255.
        padded_h, padded_w, _ = input_img.shape
        # Resize and normalize
        input_img = resize(input_img, (*self.img_shape, 3), mode='reflect')
        # Channels-first
        input_img = np.transpose(input_img, (2, 0, 1))
        # As pytorch tensor
        input_img = torch.from_numpy(input_img).float()

        #---------
        #  Label
        #---------

        label_path = self.label_files[index % len(self.img_files)].rstrip()

        labels = None
        if os.path.exists(label_path):
            labels = np.loadtxt(label_path).reshape(-1, 5)
            # Extract coordinates for unpadded + unscaled image
            x1 = w * (labels[:, 1] - labels[:, 3] / 2)
            y1 = h * (labels[:, 2] - labels[:, 4] / 2)
            x2 = w * (labels[:, 1] + labels[:, 3] / 2)
            y2 = h * (labels[:, 2] + labels[:, 4] / 2)
            # Adjust for added padding
            x1 += pad[1][0]
            y1 += pad[0][0]
            x2 += pad[1][0]
            y2 += pad[0][0]
            # Calculate ratios from coordinates
            labels[:, 1] = ((x1 + x2) / 2) / padded_w
            labels[:, 2] = ((y1 + y2) / 2) / padded_h
            labels[:, 3] *= w / padded_w
            labels[:, 4] *= h / padded_h
        # Fill matrix
        filled_labels = np.zeros((self.max_objects, 5))
        if labels is not None:
            filled_labels[range(len(labels))[:self.max_objects]] = labels[:self.max_objects]
        filled_labels = torch.from_numpy(filled_labels)

        return img_path, input_img, filled_labels

    def __len__(self):
        return len(self.img_files)
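
A minimal consumption sketch, assuming a list file such as coco/trainvalno5k.txt (produced by the download script later in this commit); the module name in the import is hypothetical, since the file's path is not preserved here:

from torch.utils.data import DataLoader
from coco_dataset import COCODataset  # module name assumed

dataset = COCODataset("coco/trainvalno5k.txt", img_size=416)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

for img_paths, input_imgs, targets in loader:
    # input_imgs: (16, 3, 416, 416) float tensor scaled to [0, 1]
    # targets:    (16, 50, 5) rows of (class, cx, cy, w, h), zero-padded
    break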
New file (training utilities):
@@ -0,0 +1,195 @@
from __future__ import division
import math
import time
import torch
import torch.nn as nn
import numpy as np


def load_classes(path):
    """
    Loads class labels at 'path'
    """
    fp = open(path, "r")
    names = fp.read().split("\n")[:-1]
    return names

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
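
This initializer is meant to be passed to Module.apply, which walks every submodule and calls the hook on each one; a minimal sketch (the toy model is illustrative, not part of this commit):

import torch.nn as nn

# Hypothetical two-layer model, for illustration only.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
)
model.apply(weights_init_normal)  # initializes the Conv2d and the BatchNorm2d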

def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    Returns the IoU of two bounding boxes
    """
    if not x1y1x2y2:
        # Transform from center and width to exact coordinates
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    # Get the coordinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)
    # Intersection area
    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * \
                 torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)
    # Union area
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

    iou = inter_area / (b1_area + b2_area - inter_area)

    return iou
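
A quick worked check of the corner-format path (note the +1 pixel convention this implementation uses for widths and heights):

import torch

a = torch.tensor([[0., 0., 10., 10.]])  # (x1, y1, x2, y2)
b = torch.tensor([[5., 5., 15., 15.]])
# Overlap is the 6x6 block [5, 10] x [5, 10]; each box spans 11x11 = 121,
# so IoU = 36 / (121 + 121 - 36) = 36 / 206, about 0.175.
print(bbox_iou(a, b))  # tensor([0.1748])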


def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections with lower object confidence score than 'conf_thres' and performs
    Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, class_score, class_pred)
    """

    # From (center x, center y, width, height) to (x1, y1, x2, y2)
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for image_i, image_pred in enumerate(prediction):
        # Filter out confidence scores below threshold
        conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
        image_pred = image_pred[conf_mask]
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
        # Iterate through all predicted classes
        unique_labels = detections[:, -1].cpu().unique()
        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
        for c in unique_labels:
            # Get the detections with the particular class
            detections_class = detections[detections[:, -1] == c]
            # Sort the detections by maximum objectness confidence
            _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
            detections_class = detections_class[conf_sort_index]
            # Perform non-maximum suppression
            max_detections = []
            while detections_class.size(0):
                # Get detection with highest confidence and save as max detection
                max_detections.append(detections_class[0].unsqueeze(0))
                # Stop if we're at the last detection
                if len(detections_class) == 1:
                    break
                # Get the IoUs for all boxes with lower confidence
                ious = bbox_iou(max_detections[-1], detections_class[1:])
                # Remove detections with IoU >= NMS threshold
                detections_class = detections_class[1:][ious < nms_thres]

            max_detections = torch.cat(max_detections).data
            # Add max detections to outputs
            output[image_i] = max_detections if output[image_i] is None else torch.cat((output[image_i], max_detections))

    return output
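
A minimal sketch of running NMS on raw network output, assuming the usual YOLO layout of (batch, boxes, 5 + num_classes) with boxes in center format; the tensor here is random, not a real prediction:

import torch

num_classes = 80
# Fake raw output for 2 images with 100 candidate boxes each:
# columns are (cx, cy, w, h, obj_conf, 80 class scores).
prediction = torch.rand(2, 100, 5 + num_classes)

detections = non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4)
# detections[i] is None, or an (N_i, 7) tensor of
# (x1, y1, x2, y2, obj_conf, class_conf, class_pred) for image i.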

def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, dim, ignore_thres):
    nB = target.size(0)
    nA = num_anchors
    nC = num_classes
    conf_mask = torch.ones(nB, nA, dim, dim)
    coord_mask = torch.zeros(nB, nA, dim, dim)
    cls_mask = torch.zeros(nB, nA, dim, dim)
    tx = torch.zeros(nB, nA, dim, dim)
    ty = torch.zeros(nB, nA, dim, dim)
    tw = torch.zeros(nB, nA, dim, dim)
    th = torch.zeros(nB, nA, dim, dim)
    tconf = torch.zeros(nB, nA, dim, dim)
    tcls = torch.zeros(nB, nA, dim, dim, num_classes)

    for b in range(nB):
        # Get sample predictions
        cur_pred_boxes = pred_boxes[b].view(-1, 4)
        cur_ious = torch.zeros(cur_pred_boxes.size(0))
        for t in range(target.shape[1]):
            if target[b, t, 1] == 0:
                break
            # Convert to position relative to the grid
            gx = target[b, t, 1] * dim
            gy = target[b, t, 2] * dim
            gw = target[b, t, 3] * dim
            gh = target[b, t, 4] * dim
            cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0)
            cur_ious = torch.max(cur_ious, bbox_iou(cur_pred_boxes.data, cur_gt_boxes.data, x1y1x2y2=False))
        # Predictions overlapping a ground truth by more than the ignore
        # threshold are excluded from the no-object (conf) mask
        conf_mask[b][cur_ious.view_as(conf_mask[b]) > ignore_thres] = 0

    nGT = 0
    nCorrect = 0
    for b in range(nB):
        for t in range(target.shape[1]):
            if target[b, t].sum() == 0:
                continue
            nGT = nGT + 1
            # Convert to position relative to the grid
            gx = target[b, t, 1] * dim
            gy = target[b, t, 2] * dim
            gw = target[b, t, 3] * dim
            gh = target[b, t, 4] * dim
            # Get grid box indices
            gi = int(gx)
            gj = int(gy)
            # Get shape of gt box
            gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0)
            # Get shapes of anchor boxes
            anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((len(anchors), 2)), np.array(anchors)), 1))
            # Calculate IoU between gt and anchor shapes
            anch_ious = bbox_iou(gt_box, anchor_shapes)
            # Find the best matching anchor box
            best_n = np.argmax(anch_ious)
            best_iou = anch_ious[best_n]
            # Get the ground truth box and corresponding best prediction
            gt_box = torch.FloatTensor(np.array([gx, gy, gw, gh])).unsqueeze(0)
            pred_box = pred_boxes[b, best_n, gj, gi].unsqueeze(0)

            # Masks
            coord_mask[b][best_n][gj][gi] = 1
            cls_mask[b][best_n][gj][gi] = 1
            conf_mask[b][best_n][gj][gi] = 1
            # Coordinates
            tx[b][best_n][gj][gi] = gx - gi
            ty[b][best_n][gj][gi] = gy - gj
            # Width and height
            tw[b][best_n][gj][gi] = math.log(gw / anchors[best_n][0] + 1e-8)
            th[b][best_n][gj][gi] = math.log(gh / anchors[best_n][1] + 1e-8)
            # Calculate IoU between ground truth and best matching prediction
            iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False)
            tconf[b][best_n][gj][gi] = iou
            tcls[b][best_n][gj][gi] = to_categorical(int(target[b, t, 0]), num_classes)

            if iou > 0.5:
                nCorrect = nCorrect + 1

    return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls


def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    return torch.from_numpy(np.eye(num_classes, dtype='uint8')[y])
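
A hedged smoke test for build_targets with dummy tensors; every number below (grid size, anchor shapes, the single ground-truth row) is illustrative only:

import torch

nB, nA, dim, nC = 2, 3, 13, 80
anchors = [(3.6, 4.9), (7.2, 8.1), (11.0, 9.8)]      # grid-scale anchors, made up
pred_boxes = torch.rand(nB, nA, dim, dim, 4) * dim   # fake (x, y, w, h) predictions
target = torch.zeros(nB, 50, 5)                      # same layout COCODataset emits
target[0, 0] = torch.tensor([0., 0.5, 0.5, 0.2, 0.3])  # (class, cx, cy, w, h)

nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls = \
    build_targets(pred_boxes, target, anchors, nA, nC, dim, ignore_thres=0.5)
# nGT == 1; masks and tx..tconf have shape (nB, nA, dim, dim),
# while tcls carries a trailing num_classes axis.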
New file (COCO class names, e.g. coco.names):
@@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
New file (COCO download script, after darknet's get_coco_dataset.sh):
@@ -0,0 +1,32 @@
#!/bin/bash

# CREDIT: https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh

# Clone COCO API
git clone https://github.com/pdollar/coco
cd coco

mkdir images
cd images

# Download Images
wget -c https://pjreddie.com/media/files/train2014.zip
wget -c https://pjreddie.com/media/files/val2014.zip

# Unzip
unzip -q train2014.zip
unzip -q val2014.zip

cd ..

# Download COCO Metadata
wget -c https://pjreddie.com/media/files/instances_train-val2014.zip
wget -c https://pjreddie.com/media/files/coco/5k.part
wget -c https://pjreddie.com/media/files/coco/trainvalno5k.part
wget -c https://pjreddie.com/media/files/coco/labels.tgz
tar xzf labels.tgz
unzip -q instances_train-val2014.zip

# Set Up Image Lists
paste <(awk "{print \"$PWD\"}" <5k.part) 5k.part | tr -d '\t' > 5k.txt
paste <(awk "{print \"$PWD\"}" <trainvalno5k.part) trainvalno5k.part | tr -d '\t' > trainvalno5k.txt
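
The last two lines prepend the absolute working directory to every path in the .part lists to produce absolute image lists. A Python sketch of the same transformation, for illustration (it assumes, as in darknet's .part files, that each line is a path fragment to be prefixed verbatim):

# Illustrative Python equivalent of the paste/awk idiom above.
import os

with open("5k.part") as src, open("5k.txt", "w") as dst:
    for line in src:
        # Each output line is <cwd> + <original line>, newline preserved.
        dst.write(os.getcwd() + line)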
Empty file.
New file (backbone registry, likely a package __init__.py):
@@ -0,0 +1,6 @@
from . import darknet

backbone_fn = {
    "darknet_21": darknet.darknet21,
    "darknet_53": darknet.darknet53,
}
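
How such a registry is typically consumed, as a hedged sketch; the constructors of darknet21/darknet53 are not shown in this commit, so the no-argument call and the package import path below are assumptions:

# Hypothetical usage sketch; assumes this package is importable as
# `backbones` and that the constructors take no required arguments.
from backbones import backbone_fn

def build_backbone(name):
    if name not in backbone_fn:
        raise ValueError("unknown backbone: %s" % name)
    return backbone_fn[name]()

# net = build_backbone("darknet_53")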