From 11210990bffd3f66cf05023b0ddbe8de3d654dac Mon Sep 17 00:00:00 2001 From: liubofang Date: Tue, 5 Jun 2018 14:59:50 +0800 Subject: [PATCH] [feature]First commit --- .gitignore | 4 + common/__init__.py | 0 common/coco_dataset.py | 77 +++++++++++++ common/utils.py | 195 +++++++++++++++++++++++++++++++++ data/coco.names | 80 ++++++++++++++ data/get_coco_dataset.sh | 32 ++++++ nets/__init__.py | 0 nets/backbone/__init__.py | 6 ++ nets/backbone/darknet.py | 94 ++++++++++++++++ nets/model_main.py | 79 ++++++++++++++ nets/yolo_loss.py | 103 ++++++++++++++++++ training/params.py | 35 ++++++ training/training.py | 222 ++++++++++++++++++++++++++++++++++++++ 13 files changed, 927 insertions(+) create mode 100644 .gitignore create mode 100644 common/__init__.py create mode 100644 common/coco_dataset.py create mode 100644 common/utils.py create mode 100644 data/coco.names create mode 100644 data/get_coco_dataset.sh create mode 100644 nets/__init__.py create mode 100644 nets/backbone/__init__.py create mode 100644 nets/backbone/darknet.py create mode 100644 nets/model_main.py create mode 100644 nets/yolo_loss.py create mode 100644 training/params.py create mode 100644 training/training.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ee85538 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.swp +*.pyc +__pycache__ +coco diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/coco_dataset.py b/common/coco_dataset.py new file mode 100644 index 0000000..befb66f --- /dev/null +++ b/common/coco_dataset.py @@ -0,0 +1,77 @@ +import os +import numpy as np + +import torch +import torchvision.transforms as transforms +from torch.utils.data import Dataset + +from PIL import Image +from skimage.transform import resize + +class COCODataset(Dataset): + def __init__(self, list_path, img_size=416): + with open(list_path, 'r') as file: + self.img_files = file.readlines() + self.label_files = [path.replace('images', 'labels').replace('.png', '.txt' + ).replace('.jpg', '.txt') for path in self.img_files] + self.img_shape = (img_size, img_size) + self.max_objects = 50 + + def __getitem__(self, index): + img_path = self.img_files[index % len(self.img_files)].rstrip() + img = np.array(Image.open(img_path)) + + # Black and white images + if len(img.shape) == 2: + img = np.repeat(img[:, :, np.newaxis], 3, axis=2) + + h, w, _ = img.shape + dim_diff = np.abs(h - w) + # Upper (left) and lower (right) padding + pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2 + # Determine padding + pad = ((pad1, pad2), (0, 0), (0, 0)) if h <= w else ((0, 0), (pad1, pad2), (0, 0)) + # Add padding + input_img = np.pad(img, pad, 'constant', constant_values=128) / 255. 
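+        # (padding with the constant 128 combined with the /255. above gives ~0.5 gray borders; pixel values are now in [0, 1])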
+        padded_h, padded_w, _ = input_img.shape
+        # Resize and normalize
+        input_img = resize(input_img, (*self.img_shape, 3), mode='reflect')
+        # Channels-first
+        input_img = np.transpose(input_img, (2, 0, 1))
+        # As pytorch tensor
+        input_img = torch.from_numpy(input_img).float()
+
+        #---------
+        #  Label
+        #---------
+
+        label_path = self.label_files[index % len(self.img_files)].rstrip()
+
+        labels = None
+        if os.path.exists(label_path):
+            labels = np.loadtxt(label_path).reshape(-1, 5)
+            # Extract coordinates for unpadded + unscaled image
+            x1 = w * (labels[:, 1] - labels[:, 3]/2)
+            y1 = h * (labels[:, 2] - labels[:, 4]/2)
+            x2 = w * (labels[:, 1] + labels[:, 3]/2)
+            y2 = h * (labels[:, 2] + labels[:, 4]/2)
+            # Adjust for added padding
+            x1 += pad[1][0]
+            y1 += pad[0][0]
+            x2 += pad[1][0]
+            y2 += pad[0][0]
+            # Calculate ratios from coordinates
+            labels[:, 1] = ((x1 + x2) / 2) / padded_w
+            labels[:, 2] = ((y1 + y2) / 2) / padded_h
+            labels[:, 3] *= w / padded_w
+            labels[:, 4] *= h / padded_h
+        # Fill matrix
+        filled_labels = np.zeros((self.max_objects, 5))
+        if labels is not None:
+            filled_labels[range(len(labels))[:self.max_objects]] = labels[:self.max_objects]
+        filled_labels = torch.from_numpy(filled_labels)
+
+        return img_path, input_img, filled_labels
+
+    def __len__(self):
+        return len(self.img_files)
diff --git a/common/utils.py b/common/utils.py
new file mode 100644
index 0000000..1327c3f
--- /dev/null
+++ b/common/utils.py
@@ -0,0 +1,195 @@
+from __future__ import division
+import math
+import time
+import torch
+import torch.nn as nn
+import numpy as np
+
+def load_classes(path):
+    """
+    Loads class labels at 'path'
+    """
+    fp = open(path, "r")
+    names = fp.read().split("\n")[:-1]
+    return names
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
+def bbox_iou(box1, box2, x1y1x2y2=True):
+    """
+    Returns the IoU of two bounding boxes
+    """
+    if not x1y1x2y2:
+        # Transform from center and width to exact coordinates
+        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
+        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
+        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
+        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
+    else:
+        # Get the coordinates of bounding boxes
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
+    # Get the coordinates of the intersection rectangle
+    inter_rect_x1 = torch.max(b1_x1, b2_x1)
+    inter_rect_y1 = torch.max(b1_y1, b2_y1)
+    inter_rect_x2 = torch.min(b1_x2, b2_x2)
+    inter_rect_y2 = torch.min(b1_y2, b2_y2)
+    # Intersection area
+    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * \
+                 torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)
+    # Union area
+    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
+    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
+
+    iou = inter_area / (b1_area + b2_area - inter_area)
+
+    return iou
+
+
+def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
+    """
+    Removes detections with lower object confidence score than 'conf_thres' and performs
+    Non-Maximum Suppression to further filter detections.
+ Returns detections with shape: + (x1, y1, x2, y2, object_conf, class_score, class_pred) + """ + + # From (center x, center y, width, height) to (x1, y1, x2, y2) + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + + output = [None for _ in range(len(prediction))] + for image_i, image_pred in enumerate(prediction): + # Filter out confidence scores below threshold + conf_mask = (image_pred[:, 4] >= conf_thres).squeeze() + image_pred = image_pred[conf_mask] + # If none are remaining => process next image + if not image_pred.size(0): + continue + # Get score and class with highest confidence + class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) + # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) + detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1) + # Iterate through all predicted classes + unique_labels = detections[:, -1].cpu().unique() + if prediction.is_cuda: + unique_labels = unique_labels.cuda() + for c in unique_labels: + # Get the detections with the particular class + detections_class = detections[detections[:, -1] == c] + # Sort the detections by maximum objectness confidence + _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True) + detections_class = detections_class[conf_sort_index] + # Perform non-maximum suppression + max_detections = [] + while detections_class.size(0): + # Get detection with highest confidence and save as max detection + max_detections.append(detections_class[0].unsqueeze(0)) + # Stop if we're at the last detection + if len(detections_class) == 1: + break + # Get the IOUs for all boxes with lower confidence + ious = bbox_iou(max_detections[-1], detections_class[1:]) + # Remove detections with IoU >= NMS threshold + detections_class = detections_class[1:][ious < nms_thres] + + max_detections = torch.cat(max_detections).data + # Add max detections to outputs + output[image_i] = max_detections if output[image_i] is None else torch.cat((output[image_i], max_detections)) + + return output + +def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, dim, ignore_thres): + nB = target.size(0) + nA = num_anchors + nC = num_classes + dim = dim + anchor_step = len(anchors)/num_anchors + conf_mask = torch.ones(nB, nA, dim, dim) + coord_mask = torch.zeros(nB, nA, dim, dim) + cls_mask = torch.zeros(nB, nA, dim, dim) + tx = torch.zeros(nB, nA, dim, dim) + ty = torch.zeros(nB, nA, dim, dim) + tw = torch.zeros(nB, nA, dim, dim) + th = torch.zeros(nB, nA, dim, dim) + tconf = torch.zeros(nB, nA, dim, dim) + tcls = torch.zeros(nB, nA, dim, dim, num_classes) + + for b in range(nB): + # Get sample predictions + cur_pred_boxes = pred_boxes[b].view(-1, 4) + cur_ious = torch.zeros(cur_pred_boxes.size(0)) + for t in range(target.shape[1]): + if target[b, t, 1] == 0: + break + # Convert to position relative to box + gx = target[b, t, 1] * dim + gy = target[b, t, 2] * dim + gw = target[b, t, 3] * dim + gh = target[b, t, 4] * dim + cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0) + cur_ious = torch.max(cur_ious, bbox_iou(cur_pred_boxes.data, cur_gt_boxes.data, x1y1x2y2=False)) + # Objects with highest confidence than threshold 
are set to zero + conf_mask[b][cur_ious.view_as(conf_mask[b]) > ignore_thres] = 0 + + nGT = 0 + nCorrect = 0 + for b in range(nB): + for t in range(target.shape[1]): + if target[b, t].sum() == 0: + continue + nGT = nGT + 1 + # Convert to position relative to box + gx = target[b, t, 1] * dim + gy = target[b, t, 2] * dim + gw = target[b, t, 3] * dim + gh = target[b, t, 4] * dim + # Get grid box indices + gi = int(gx) + gj = int(gy) + # Get shape of gt box + gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0) + # Get shape of anchor box + anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((len(anchors), 2)), np.array(anchors)), 1)) + # Calculate iou between gt and anchor shape + anch_ious = bbox_iou(gt_box, anchor_shapes) + # Find the best matching anchor box + best_n = np.argmax(anch_ious) + best_iou = anch_ious[best_n] + # Get the ground truth box and corresponding best prediction + gt_box = torch.FloatTensor(np.array([gx, gy, gw, gh])).unsqueeze(0) + pred_box = pred_boxes[b, best_n, gj, gi].unsqueeze(0) + + # Masks + coord_mask[b][best_n][gj][gi] = 1 + cls_mask[b][best_n][gj][gi] = 1 + conf_mask[b][best_n][gj][gi] = 1 + # Coordinates + tx[b][best_n][gj][gi] = gx - gi + ty[b][best_n][gj][gi] = gy - gj + # Width and height + tw[b][best_n][gj][gi] = math.log(gw/anchors[best_n][0] + 1e-8) + th[b][best_n][gj][gi] = math.log(gh/anchors[best_n][1] + 1e-8) + # Calculate iou between ground truth and best matching prediction + iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False) + tconf[b][best_n][gj][gi] = iou + tcls[b][best_n][gj][gi] = to_categorical(int(target[b, t, 0]), num_classes) + + if iou > 0.5: + nCorrect = nCorrect + 1 + + return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls + +def to_categorical(y, num_classes): + """ 1-hot encodes a tensor """ + return torch.from_numpy(np.eye(num_classes, dtype='uint8')[y]) diff --git a/data/coco.names b/data/coco.names new file mode 100644 index 0000000..ca76c80 --- /dev/null +++ b/data/coco.names @@ -0,0 +1,80 @@ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/data/get_coco_dataset.sh b/data/get_coco_dataset.sh new file mode 100644 index 0000000..81b0017 --- /dev/null +++ b/data/get_coco_dataset.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# CREDIT: https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh + +# Clone COCO API +git clone https://github.com/pdollar/coco +cd coco + +mkdir images +cd images + +# Download Images +wget -c https://pjreddie.com/media/files/train2014.zip +wget -c https://pjreddie.com/media/files/val2014.zip + +# Unzip +unzip -q train2014.zip +unzip -q val2014.zip + +cd .. 
+
+# Download COCO Metadata
+wget -c https://pjreddie.com/media/files/instances_train-val2014.zip
+wget -c https://pjreddie.com/media/files/coco/5k.part
+wget -c https://pjreddie.com/media/files/coco/trainvalno5k.part
+wget -c https://pjreddie.com/media/files/coco/labels.tgz
+tar xzf labels.tgz
+unzip -q instances_train-val2014.zip
+
+# Set Up Image Lists
+paste <(awk "{print \"$PWD\"}" <5k.part) 5k.part | tr -d '\t' > 5k.txt
+paste <(awk "{print \"$PWD\"}" <trainvalno5k.part) trainvalno5k.part | tr -d '\t' > trainvalno5k.txt
diff --git a/nets/__init__.py b/nets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nets/backbone/__init__.py b/nets/backbone/__init__.py
new file mode 100644
index 0000000..b2626ff
--- /dev/null
+++ b/nets/backbone/__init__.py
@@ -0,0 +1,6 @@
+from . import darknet
+
+backbone_fn = {
+    "darknet_21": darknet.darknet21,
+    "darknet_53": darknet.darknet53,
+}
diff --git a/nets/backbone/darknet.py b/nets/backbone/darknet.py
new file mode 100644
index 0000000..9913a82
--- /dev/null
+++ b/nets/backbone/darknet.py
@@ -0,0 +1,94 @@
+import torch.nn as nn
+import math
+
+__all__ = ['darknet21', 'darknet53']
+
+
+class BasicBlock(nn.Module):
+    def __init__(self, inplanes, planes):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
+                               stride=1, padding=0, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes[0])
+        self.relu1 = nn.LeakyReLU(0.1)
+        self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes[1])
+        self.relu2 = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu2(out)
+
+        out += residual
+        return out
+
+
+class DarkNet(nn.Module):
+    def __init__(self, layers):
+        super(DarkNet, self).__init__()
+        self.inplanes = 32
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.relu1 = nn.LeakyReLU(0.1)
+
+        self.layer1 = self._make_layer([32, 64], layers[0])
+        self.layer2 = self._make_layer([64, 128], layers[1])
+        self.layer3 = self._make_layer([128, 256], layers[2])
+        self.layer4 = self._make_layer([256, 512], layers[3])
+        self.layer5 = self._make_layer([512, 1024], layers[4])
+
+        self.layers_out_filters = [64, 128, 256, 512, 1024]
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, planes, blocks):
+        layers = []
+        # downsample
+        layers.append(nn.Conv2d(self.inplanes, planes[1], kernel_size=3,
+                                stride=2, padding=1, bias=False))
+        layers.append(nn.BatchNorm2d(planes[1]))
+        layers.append(nn.LeakyReLU(0.1))
+        # blocks
+        self.inplanes = planes[1]
+        for i in range(0, blocks):
+            layers.append(BasicBlock(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        out3 = self.layer3(x)
+        out4 = self.layer4(out3)
+        out5 = self.layer5(out4)
+
+        return out3, out4, out5
+
+def darknet21(**kwargs):
+    """Constructs a darknet-21 model.
+    """
+    model = DarkNet([1, 1, 2, 2, 1])
+    return model
+
+def darknet53(**kwargs):
+    """Constructs a darknet-53 model.
+ """ + model = DarkNet([1, 2, 8, 8, 4]) + return model diff --git a/nets/model_main.py b/nets/model_main.py new file mode 100644 index 0000000..bba5fcc --- /dev/null +++ b/nets/model_main.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn + +from .backbone import backbone_fn + + +class ModelMain(nn.Module): + def __init__(self, config, is_training=True): + super(ModelMain, self).__init__() + self.config = config + self.training = is_training + self.model_params = config["model_params"] + # backbone + _backbone_fn = backbone_fn[self.model_params["backbone_name"]] + self.backbone = _backbone_fn() + _out_filters = self.backbone.layers_out_filters + # embedding0 + self.embedding0 = self._make_embedding([512, 1024], _out_filters[-1]) + # embedding1 + self.embedding1 = self._make_embedding([256, 512], _out_filters[-2] + 256) + self.embedding1_cbl = self._make_cbl(512, 256, 1) + self.embedding1_upsample = nn.Upsample(scale_factor=2, mode='nearest') + # embedding2 + self.embedding2 = self._make_embedding([128, 256], _out_filters[-3] + 128) + self.embedding2_cbl = self._make_cbl(256, 128, 1) + self.embedding2_upsample = nn.Upsample(scale_factor=2, mode='nearest') + + def _make_cbl(self, _in, _out, ks): + ''' cbl = conv + batch_norm + leaky_relu + ''' + pad = (ks - 1) // 2 if ks else 0 + return nn.Sequential( + nn.Conv2d(_in, _out, kernel_size=ks, stride=1, padding=pad, bias=False), + nn.BatchNorm2d(_out), + nn.LeakyReLU(0.1), + ) + + def _make_embedding(self, filters_list, in_filters): + return nn.ModuleList([ + self._make_cbl(in_filters, filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], 255, 1)]) + + def forward(self, x): + def _branch(_embedding, _in): + for i in range(len(_embedding)): + _in = _embedding[i](_in) + if i == 4: + out_branch = _in + return _in, out_branch + # backbone + x2, x1, x0 = self.backbone(x) + # yolo branch 0 + out0, out0_branch = _branch(self.embedding0, x0) + # yolo branch 1 + x1_in = self.embedding1_cbl(out0_branch) + x1_in = self.embedding1_upsample(x1_in) + x1_in = torch.cat([x1_in, x1], 1) + out1, out1_branch = _branch(self.embedding1, x1_in) + # yolo branch 2 + x2_in = self.embedding2_cbl(out1_branch) + x2_in = self.embedding2_upsample(x2_in) + x2_in = torch.cat([x2_in, x2], 1) + out2, out2_branch = _branch(self.embedding2, x2_in) + return out0, out1, out2 + +if __name__ == "__main__": + config = {"model_params": {"backbone_name": "darknet_53"}} + m = ModelMain(config) + x = torch.randn(1, 3, 416, 416) + y0, y1, y2 = m(x) + print(y0.size()) + print(y1.size()) + print(y2.size()) + diff --git a/nets/yolo_loss.py b/nets/yolo_loss.py new file mode 100644 index 0000000..a08084f --- /dev/null +++ b/nets/yolo_loss.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable + +from common.utils import build_targets + +class YOLOLoss(nn.Module): + def __init__(self, anchors, num_classes, image_dim): + super(YOLOLoss, self).__init__() + self.anchors = anchors + self.scaled_anchors = None + self.num_anchors = len(anchors) + self.num_classes = num_classes + self.bbox_attrs = 5 + num_classes + self.image_dim = image_dim + self.ignore_thres = 0.5 + self.coord_scale = 1 + self.noobject_scale = 1 + self.object_scale = 5 + self.class_scale = 1 + self.seen = 0 + + 
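+        # Loss terms: MSE for the box offsets below, BCE for objectness and per-class scores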
self.mse_loss = nn.MSELoss() + self.bce_loss = nn.BCELoss() + self.bce_logits_loss = nn.BCEWithLogitsLoss() + + def forward(self, x, targets=None): + bs = x.size(0) + g_dim = x.size(2) + stride = self.image_dim / g_dim + # Tensors for cuda support + FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor + LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor + + prediction = x.view(bs, self.num_anchors, self.bbox_attrs, g_dim, g_dim).permute(0, 1, 3, 4, 2).contiguous() + + # Get outputs + x = torch.sigmoid(prediction[..., 0]) # Center x + y = torch.sigmoid(prediction[..., 1]) # Center y + w = prediction[..., 2] # Width + h = prediction[..., 3] # Height + conf = torch.sigmoid(prediction[..., 4]) # Conf + pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred. + + # Calculate offsets for each grid + grid_x = torch.linspace(0, g_dim-1, g_dim).repeat(g_dim,1).repeat(bs*self.num_anchors, 1, 1).view(x.shape).type(FloatTensor) + grid_y = torch.linspace(0, g_dim-1, g_dim).repeat(g_dim,1).t().repeat(bs*self.num_anchors, 1, 1).view(y.shape).type(FloatTensor) + scaled_anchors = [(a_w / stride, a_h / stride) for a_w, a_h in self.anchors] + anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0])) + anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1])) + anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, g_dim*g_dim).view(w.shape) + anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, g_dim*g_dim).view(h.shape) + + # Add offset and scale with anchors + pred_boxes = FloatTensor(prediction[..., :4].shape) + pred_boxes[..., 0] = x.data + grid_x + pred_boxes[..., 1] = y.data + grid_y + pred_boxes[..., 2] = torch.exp(w.data) * anchor_w + pred_boxes[..., 3] = torch.exp(h.data) * anchor_h + + self.seen += prediction.size(0) + + # Training + if targets is not None: + + if x.is_cuda: + self.mse_loss = self.mse_loss.cuda() + self.bce_loss = self.bce_loss.cuda() + + nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx, ty, tw, th, tconf, tcls = build_targets(pred_boxes.cpu().data, + targets.cpu().data, + scaled_anchors, + self.num_anchors, + self.num_classes, + g_dim, + self.ignore_thres) + + + nProposals = int((conf > 0.25).sum().item()) + + tx = Variable(tx.type(FloatTensor), requires_grad=False) + ty = Variable(ty.type(FloatTensor), requires_grad=False) + tw = Variable(tw.type(FloatTensor), requires_grad=False) + th = Variable(th.type(FloatTensor), requires_grad=False) + tconf = Variable(tconf.type(FloatTensor), requires_grad=False) + tcls = Variable(tcls[cls_mask == 1].type(FloatTensor), requires_grad=False) + coord_mask = Variable(coord_mask.type(FloatTensor), requires_grad=False) + conf_mask = Variable(conf_mask.type(FloatTensor), requires_grad=False) + + loss_x = self.coord_scale * self.mse_loss(x[coord_mask == 1], tx[coord_mask == 1]) / 2 + loss_y = self.coord_scale * self.mse_loss(y[coord_mask == 1], ty[coord_mask == 1]) / 2 + loss_w = self.coord_scale * self.mse_loss(w[coord_mask == 1], tw[coord_mask == 1]) / 2 + loss_h = self.coord_scale * self.mse_loss(h[coord_mask == 1], th[coord_mask == 1]) / 2 + loss_conf = self.bce_loss(conf[conf_mask == 1], tconf[conf_mask == 1]) + loss_cls = self.class_scale * self.bce_loss(pred_cls[cls_mask == 1], tcls) + loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls + + return loss, loss_x.item(), loss_y.item(), loss_w.item(), loss_h.item(), loss_conf.item(), loss_cls.item() + + else: + # If not in training phase return predictions + output = torch.cat((pred_boxes.view(bs, -1, 4) * stride, conf.view(bs, 
-1, 1), pred_cls.view(bs, -1, self.num_classes)), -1) + return output.data diff --git a/training/params.py b/training/params.py new file mode 100644 index 0000000..57a6905 --- /dev/null +++ b/training/params.py @@ -0,0 +1,35 @@ +TRAINING_PARAMS = \ +{ + "model_params": { + "backbone_name": "darknet_53", + # "backbone_imagenet_pretrain": False, + }, + "yolo": { + "anchors": [[[10, 13], [16, 30], [33, 23]], + [[30, 61], [62, 45], [59, 119]], + [[116, 90], [156, 198], [373, 326]]], + "classes": 80, + }, + "lr": { + "backbone_lr": 0.01, + "other_lr": 0.01, + "freeze_backbone": False, + "decay_gamma": 0.01, + "decay_step": 30, + }, + "optimizer": { + "type": "sgd", + "weight_decay": 4e-05, + }, + "batch_size": 16, + "train_path": "/home/liubofang/bob/YOLOv3_PyTorch/data/coco/trainvalno5k.txt", + "epochs": 100, + "img_h": 416, + "img_w": 416, + "parallels": [4,5,6,7], + "working_dir": "/world/data-c9/liubofang/training/yolo3/pytorch", + "pretrain_snapshot": "", + "evaluate_type": "", + "try": 1060, + "export_onnx": False, +} diff --git a/training/training.py b/training/training.py new file mode 100644 index 0000000..9c744a5 --- /dev/null +++ b/training/training.py @@ -0,0 +1,222 @@ +# coding='utf-8' +import os +import sys +import numpy as np +import time +import datetime +import json +import importlib +import logging +import shutil + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +from tensorboardX import SummaryWriter + +MY_DIRNAME = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(MY_DIRNAME, '..')) +# sys.path.insert(0, os.path.join(MY_DIRNAME, '..', 'evaluate')) +from nets.model_main import ModelMain +from nets.yolo_loss import YOLOLoss +from common.coco_dataset import COCODataset +# from common import model_utils + + +def train(config): + config["global_step"] = config.get("start_step", 0) + is_training = False if config.get("export_onnx") else True + + # Load and initialize network + net = ModelMain(config, is_training=is_training) + net.train(is_training) + + # Optimizer and learning rate + optimizer = _get_optimizer(config, net) + lr_scheduler = optim.lr_scheduler.StepLR( + optimizer, + step_size=config["lr"]["decay_step"], + gamma=config["lr"]["decay_gamma"]) + + # Set data parallel + net = nn.DataParallel(net) + net = net.cuda() + + # Restore pretrain model + # if os.path.exists(config.get("pretrain_snapshot", "")): + # model_utils.restore_model(config["pretrain_snapshot"], net, eval_mode=(is_training==False)) + + # Only export onnx + # if config.get("export_onnx"): + # real_model = net.module + # real_model.eval() + # dummy_input = torch.randn(8, 3, config["img_h"], config["img_w"]).cuda() + # save_path = os.path.join(config["sub_working_dir"], "pytorch.onnx") + # logging.info("Exporting onnx to {}".format(save_path)) + # torch.onnx.export(real_model, dummy_input, save_path, verbose=False) + # logging.info("Done. 
Exiting now.")
+    #     sys.exit()
+
+    # Evaluate interface
+    # if config["evaluate_type"]:
+    #     logging.info("Using {} to evaluate model.".format(config["evaluate_type"]))
+    #     evaluate_func = importlib.import_module(config["evaluate_type"]).run_eval
+    #     config["online_net"] = net
+
+    # YOLO loss with 3 scales, one per output stride
+    yolo_losses = []
+    for i in range(3):
+        yolo_losses.append(YOLOLoss(config["yolo"]["anchors"][i],
+                                    config["yolo"]["classes"], config["img_h"]))
+
+    # DataLoader
+    dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"]),
+                                             batch_size=config["batch_size"],
+                                             shuffle=True, num_workers=16, pin_memory=False)
+
+    # Start the training loop
+    logging.info("Start training.")
+    for epoch in range(config["epochs"]):
+        for step, (_, images, labels) in enumerate(dataloader):
+            start_time = time.time()
+            config["global_step"] += 1
+
+            # Forward and backward
+            optimizer.zero_grad()
+            outputs = net(images)
+            losses_name = ["total_loss", "x", "y", "w", "h", "conf", "cls"]
+            losses = [[] for _ in range(len(losses_name))]
+            for i in range(3):
+                _loss_item = yolo_losses[i](outputs[i], labels)
+                for j, l in enumerate(_loss_item):
+                    losses[j].append(l)
+            losses = [sum(l) for l in losses]
+            loss = losses[0]
+            loss.backward()
+            optimizer.step()
+
+            if step > 0 and step % 10 == 0:
+                _loss = loss.item()
+                duration = float(time.time() - start_time)
+                example_per_second = config["batch_size"] / duration
+                lr = optimizer.param_groups[0]['lr']
+                logging.info(
+                    "epoch [%.3d] iter = %d loss = %.2f example/sec = %.3f lr = %.5f "%
+                    (epoch, step, _loss, example_per_second, lr)
+                )
+                config["tensorboard_writer"].add_scalar("lr",
+                                                        lr,
+                                                        config["global_step"])
+                config["tensorboard_writer"].add_scalar("example/sec",
+                                                        example_per_second,
+                                                        config["global_step"])
+                for i, name in enumerate(losses_name):
+                    value = _loss if i == 0 else losses[i]
+                    config["tensorboard_writer"].add_scalar(name,
+                                                            value,
+                                                            config["global_step"])
+
+            if step > 0 and step % 1000 == 0:
+                # net.train(False)
+                _save_checkpoint(net.state_dict(), config)
+                # net.train(True)
+
+        lr_scheduler.step()
+
+    # net.train(False)
+    _save_checkpoint(net.state_dict(), config)
+    # net.train(True)
+    logging.info("Bye~")
+
+# best_eval_result = 0.0
+def _save_checkpoint(state_dict, config, evaluate_func=None):
+    # global best_eval_result
+    checkpoint_path = os.path.join(config["sub_working_dir"], "model.pth")
+    torch.save(state_dict, checkpoint_path)
+    logging.info("Model checkpoint saved to %s" % checkpoint_path)
+    # eval_result = evaluate_func(config)
+    # if eval_result > best_eval_result:
+    #     best_eval_result = eval_result
+    #     logging.info("New best result: {}".format(best_eval_result))
+    #     best_checkpoint_path = os.path.join(config["sub_working_dir"], 'model_best.pth')
+    #     shutil.copyfile(checkpoint_path, best_checkpoint_path)
+    #     logging.info("Best checkpoint saved to {}".format(best_checkpoint_path))
+    # else:
+    #     logging.info("Best result: {}".format(best_eval_result))
+
+
+def _get_optimizer(config, net):
+    optimizer = None
+
+    # Assign different lr for each layer
+    params = None
+    base_params = list(
+        map(id, net.backbone.parameters())
+    )
+    logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
+
+    if not config["lr"]["freeze_backbone"]:
+        params = [
+            {"params": logits_params, "lr": config["lr"]["other_lr"]},
+            {"params": net.backbone.parameters(), "lr": config["lr"]["backbone_lr"]},
+        ]
+    else:
+        logging.info("freeze backbone's parameters.")
+        for p in net.backbone.parameters():
+            p.requires_grad = False
+        params = [
+            {"params": logits_params, "lr": config["lr"]["other_lr"]},
+        ]
+
+    # Initialize optimizer class
+    if config["optimizer"]["type"] == "adam":
+        optimizer = optim.Adam(params, weight_decay=config["optimizer"]["weight_decay"])
+    elif config["optimizer"]["type"] == "amsgrad":
+        optimizer = optim.Adam(params, weight_decay=config["optimizer"]["weight_decay"],
+                               amsgrad=True)
+    elif config["optimizer"]["type"] == "rmsprop":
+        optimizer = optim.RMSprop(params, weight_decay=config["optimizer"]["weight_decay"])
+    else:
+        # Default to sgd
+        logging.info("Using SGD optimizer.")
+        optimizer = optim.SGD(params, momentum=0.9,
+                              weight_decay=config["optimizer"]["weight_decay"],
+                              nesterov=(config["optimizer"]["type"] == "nesterov"))
+
+    return optimizer
+
+def main():
+    logging.basicConfig(level=logging.DEBUG,
+                        format="[%(asctime)s %(filename)s] %(message)s")
+
+    if len(sys.argv) != 2:
+        logging.error("Usage: python training.py params.py")
+        sys.exit()
+    params_path = sys.argv[1]
+    if not os.path.isfile(params_path):
+        logging.error("no params file found! path: {}".format(params_path))
+        sys.exit()
+    config = importlib.import_module(params_path[:-3]).TRAINING_PARAMS
+    config["batch_size"] *= len(config["parallels"])
+
+    # Create sub_working_dir
+    sub_working_dir = '{}/{}/size{}x{}_try{}/{}'.format(
+        config['working_dir'], config['model_params']['backbone_name'],
+        config['img_w'], config['img_h'], config['try'],
+        time.strftime("%Y%m%d%H%M%S", time.localtime()))
+    if not os.path.exists(sub_working_dir):
+        os.makedirs(sub_working_dir)
+    config["sub_working_dir"] = sub_working_dir
+    logging.info("sub working dir: %s" % sub_working_dir)
+
+    # Create tf_summary writer
+    config["tensorboard_writer"] = SummaryWriter(sub_working_dir)
+    logging.info("Please use 'python -m tensorboard.main --logdir={}'".format(sub_working_dir))
+
+    # Start training
+    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, config["parallels"]))
+    train(config)

+if __name__ == "__main__":
+    main()