Skip to content

Commit

Permalink
Initial
Browse files Browse the repository at this point in the history
  • Loading branch information
cleardusk committed Jun 29, 2018
0 parents commit 2595400
Show file tree
Hide file tree
Showing 8 changed files with 883 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.idea/
*.pyc
102 changes: 102 additions & 0 deletions ddfa_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# coding: utf-8

import os.path as osp
from pathlib import Path
import numpy as np

import torch
import torch.utils.data as data
import cv2
import pickle
import argparse
from io_utils import _numpy_to_tensor, _load_cpu, _load_gpu


def img_loader(path):
    """Read the image file at ``path`` with OpenCV in colour mode.

    Returns whatever ``cv2.imread`` yields for ``path`` when called with
    the ``cv2.IMREAD_COLOR`` flag.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    return image


def str2bool(v):
    """Convert a command-line style yes/no value into a ``bool``.

    Suitable as an ``argparse`` ``type=`` callable.

    Args:
        v: One of 'yes', 'true', 't', 'y', '1' (truthy) or 'no', 'false',
            'f', 'n', '0' (falsy), compared case-insensitively and ignoring
            surrounding whitespace. An actual ``bool`` is returned as-is.

    Returns:
        The boolean the value stands for.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised spelling.
    """
    # A real bool has no .lower(); pass it straight through so that
    # programmatic callers and bool defaults do not crash.
    if isinstance(v, bool):
        return v

    text = v.strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected')


def _parse_param(param):
    """Split a flat parameter vector into four pieces (numpy array or tensor).

    The first 12 entries are reshaped to three rows; the leading three
    columns of that matrix form ``p`` and its last column forms ``offset``
    (as a 3x1 column). Entries 12..51 and 52.. are returned as column
    vectors ``alpha_shp`` and ``alpha_exp``.

    Returns:
        Tuple ``(p, offset, alpha_shp, alpha_exp)``.
    """
    # Tail segments first; they do not depend on the leading block.
    alpha_exp = param[52:].reshape(-1, 1)
    alpha_shp = param[12:52].reshape(-1, 1)

    head = param[:12].reshape(3, -1)
    offset = head[:, -1].reshape(3, 1)
    p = head[:, :3]
    return p, offset, alpha_shp, alpha_exp


class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes are plain numbers (or whatever type ``val`` supports
    ``*``/``+``/``/`` with):
        val:   the most recent value passed to ``update``.
        sum:   the weighted sum of all values seen since the last reset.
        count: the total weight (sum of ``n``) since the last reset.
        avg:   ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new observation.
            n: How many samples ``val`` stands for (default 1).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. an empty batch reported
        # with n=0 right after a reset), which used to divide by zero.
        self.avg = self.sum / self.count if self.count else 0


class ToTensorGjz(object):
    """Transform turning a numpy image array into a float tensor.

    The array's axes are permuted with ``transpose((2, 0, 1))`` (last axis
    moved to the front) and the result is cast with ``.float()``. No value
    scaling is applied.
    """

    def __call__(self, pic):
        """Convert ``pic`` to a float tensor.

        Args:
            pic: A 3-dimensional ``numpy.ndarray``.

        Returns:
            A float tensor with the last axis of ``pic`` moved to the front.

        Raises:
            TypeError: If ``pic`` is not a ``numpy.ndarray`` (for instance
                ``None`` coming from an image that failed to load).
        """
        # Fail loudly instead of letting an unsupported input slip through
        # and blow up later in the pipeline with an unrelated error.
        if not isinstance(pic, np.ndarray):
            raise TypeError(
                'pic should be a numpy.ndarray, got {}'.format(type(pic).__name__))
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        return img.float()

    def __repr__(self):
        return self.__class__.__name__ + '()'


class NormalizeGjz(object):
    """Transform that shifts and scales a tensor in place.

    Calling the instance subtracts ``mean`` from the tensor and then divides
    it by ``std``, modifying the given tensor itself and returning it.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        # Both steps are in-place: the caller's tensor is mutated.
        tensor.sub_(self.mean)
        tensor.div_(self.std)
        return tensor


class DDFADataset(data.Dataset):
    """Dataset pairing image files with rows of a parameter array.

    Sample ``i`` is the image named on line ``i`` of the list file (joined
    onto ``root``) together with ``params[i]``.

    Args:
        root: Directory the listed image paths are relative to.
        filelists: Path of a text file with one image path per line.
        param_fp: Path of the parameter file, read via ``_load_cpu`` and
            converted with ``_numpy_to_tensor``.
        transform: Optional callable applied to each loaded image.
        **kargs: Accepted for call compatibility; not used.
    """

    def __init__(self, root, filelists, param_fp, transform=None, **kargs):
        self.root = root
        self.transform = transform
        # splitlines() also copes with '\r\n' line endings and gives an
        # empty list (length 0) for an empty list file, where
        # split('\n') kept stray '\r' and produced [''].
        self.lines = Path(filelists).read_text().strip().splitlines()
        self.params = _numpy_to_tensor(_load_cpu(param_fp))
        self.img_loader = img_loader

    def _target_loader(self, index):
        """Return the parameter row for sample ``index``."""
        return self.params[index]

    def __getitem__(self, index):
        """Load sample ``index`` as ``(img, target)``.

        Raises:
            OSError: If the image loader returns ``None`` for the file
                (which is how ``cv2.imread`` reports an unreadable image).
        """
        path = osp.join(self.root, self.lines[index])
        img = self.img_loader(path)
        if img is None:
            # Surface the offending path instead of handing None to the
            # transform / collate step, where the failure is far harder to trace.
            raise OSError(f'Failed to read image: {path}')

        target = self._target_loader(index)

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.lines)
64 changes: 64 additions & 0 deletions io_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# coding: utf-8

import os
import numpy as np
import torch
import pickle


def mkdir(d):
    """Create directory ``d`` together with any missing parent directories.

    Does nothing when ``d`` is empty or when something already exists at
    that path.

    Args:
        d: Path of the directory to create.
    """
    # os.makedirs replaces the former `mkdir -p` shell call: no shell is
    # involved, so paths containing spaces or shell metacharacters are
    # handled correctly and the function works on every platform.
    if d and not os.path.exists(d):
        # exist_ok covers another process creating it between check and call.
        os.makedirs(d, exist_ok=True)


def _get_suffix(filename):
    """Return the text after the last '.' in ``filename`` (a.jpg -> jpg).

    Yields an empty string when ``filename`` contains no '.'.
    """
    _, dot, suffix = filename.rpartition('.')
    return suffix if dot else ''


def _load(fp):
    """Load an object from ``fp``, choosing the reader by file suffix.

    Args:
        fp: Path ending in ``.npy`` (read with ``np.load``) or ``.pkl``
            (read with ``pickle.load``).

    Returns:
        The loaded object, or ``None`` when the suffix is neither
        ``npy`` nor ``pkl``.
    """
    suffix = _get_suffix(fp)
    if suffix == 'npy':
        return np.load(fp)
    if suffix == 'pkl':
        # NOTE: unpickling can execute arbitrary code -- only load files
        # from a trusted source.
        # The context manager closes the handle, which was previously leaked.
        with open(fp, 'rb') as f:
            return pickle.load(f)
    # Unsupported suffixes keep yielding None, as before.
    return None


def _dump(wfp, obj):
    """Write ``obj`` to ``wfp``, choosing the writer by file suffix.

    Args:
        wfp: Destination path ending in ``.npy`` (written with ``np.save``)
            or ``.pkl`` (written with ``pickle.dump``).
        obj: The object to store.

    Raises:
        Exception: If the suffix of ``wfp`` is neither ``npy`` nor ``pkl``.
    """
    suffix = _get_suffix(wfp)
    if suffix == 'npy':
        np.save(wfp, obj)
    elif suffix == 'pkl':
        # Close (and thereby flush) the file deterministically; the handle
        # used to be left for the garbage collector.
        with open(wfp, 'wb') as f:
            pickle.dump(obj, f)
    else:
        raise Exception(f'Unknown Type: {suffix}')


def _load_tensor(fp, mode='cpu'):
    """Load the array stored at ``fp`` as a torch tensor.

    Args:
        fp: Path handed to ``_load``.
        mode: 'cpu' to keep the tensor as created by ``torch.from_numpy``,
            'gpu' to additionally call ``.cuda()`` on it (case-insensitive).

    Returns:
        The tensor, or ``None`` for any other ``mode`` (nothing is loaded
        in that case).
    """
    target = mode.lower()
    if target not in ('cpu', 'gpu'):
        return None

    tensor = torch.from_numpy(_load(fp))
    return tensor.cuda() if target == 'gpu' else tensor


def _tensor_to_cuda(x):
    """Return ``x`` itself if ``x.is_cuda`` is set, otherwise ``x.cuda()``."""
    return x if x.is_cuda else x.cuda()


def _load_gpu(fp):
    """Load ``fp`` via ``_load``, wrap it as a tensor and call ``.cuda()`` on it."""
    array = _load(fp)
    tensor = torch.from_numpy(array)
    return tensor.cuda()


# Alias: `_load_cpu` is the very same function object as `_load`.
_load_cpu = _load
def _numpy_to_tensor(x):
    """Wrap numpy array ``x`` as a tensor via ``torch.from_numpy``."""
    return torch.from_numpy(x)


def _tensor_to_numpy(x):
    """Return tensor ``x`` as a numpy array (moved to CPU first).

    Previously this returned ``x.cpu()`` -- still a tensor -- which
    contradicted the function's name.
    """
    return x.cpu().numpy()


def _numpy_to_cuda(x):
    """Wrap numpy array ``x`` as a tensor and pass it to ``_tensor_to_cuda``."""
    return _tensor_to_cuda(torch.from_numpy(x))


def _cuda_to_tensor(x):
    """Return ``x.cpu()``."""
    return x.cpu()


def _cuda_to_numpy(x):
    """Return ``x`` moved to CPU and converted to a numpy array."""
    return x.cpu().numpy()
154 changes: 154 additions & 0 deletions mobilenet_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#!/usr/bin/env python3
# coding: utf-8

from __future__ import division

"""
Creates a MobileNet Model as defined in:
Andrew G. Howard, Menglong Zhu, Bo Chen, et al. (2017).
MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications.
Copyright (c) Yang Lu, 2017
Modified By cleardusk
"""
import math
import torch.nn as nn

__all__ = ['mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 'mobilenet_025']


class DepthWiseBlock(nn.Module):
    """3x3 grouped convolution followed by a 1x1 convolution.

    Each convolution is followed by batch normalisation and the block's
    activation. The first convolution uses ``groups=inplanes`` (one filter
    group per input channel) and carries the ``stride``; the second maps
    ``inplanes`` channels to ``planes`` channels.

    Args:
        inplanes: Number of input channels (converted with ``int``).
        planes: Number of output channels (converted with ``int``).
        stride: Stride of the 3x3 convolution.
        prelu: Use ``nn.PReLU`` as activation instead of in-place ``nn.ReLU``.
    """

    def __init__(self, inplanes, planes, stride=1, prelu=False):
        super().__init__()
        in_ch = int(inplanes)
        out_ch = int(planes)

        # Per-channel 3x3 stage.
        self.conv_dw = nn.Conv2d(
            in_ch, in_ch,
            kernel_size=3, stride=stride, padding=1,
            groups=in_ch, bias=False,
        )
        self.bn_dw = nn.BatchNorm2d(in_ch)

        # 1x1 stage that mixes channels.
        self.conv_sep = nn.Conv2d(
            in_ch, out_ch,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn_sep = nn.BatchNorm2d(out_ch)

        # A single activation module is applied after both stages.
        self.relu = nn.PReLU() if prelu else nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.bn_dw(self.conv_dw(x)))
        return self.relu(self.bn_sep(self.conv_sep(hidden)))


class MobileNet(nn.Module):
    """MobileNet-style network built from ``DepthWiseBlock`` units.

    Layout: a strided 3x3 stem convolution with batch norm and activation,
    thirteen ``DepthWiseBlock`` stages (``dw2_1`` .. ``dw6``), adaptive
    average pooling to 1x1 and a final linear layer.
    """

    def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3):
        """Build the network.

        Args:
            widen_factor: Multiplier applied to every base channel count
                (the result is truncated with ``int``).
            num_classes: Output size of the final linear layer.
            prelu: Use ``nn.PReLU`` activations instead of in-place ``nn.ReLU``.
            input_channel: Number of channels of the input.
        """
        super().__init__()

        def width(base):
            # Scaled channel count for a base width.
            return int(base * widen_factor)

        stem_width = width(32)
        self.conv1 = nn.Conv2d(
            input_channel, stem_width,
            kernel_size=3, stride=2, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu = nn.PReLU() if prelu else nn.ReLU(inplace=True)

        # Attribute names and creation order are kept exactly as they were.
        self.dw2_1 = DepthWiseBlock(width(32), width(64), prelu=prelu)
        self.dw2_2 = DepthWiseBlock(width(64), width(128), stride=2, prelu=prelu)

        self.dw3_1 = DepthWiseBlock(width(128), width(128), prelu=prelu)
        self.dw3_2 = DepthWiseBlock(width(128), width(256), stride=2, prelu=prelu)

        self.dw4_1 = DepthWiseBlock(width(256), width(256), prelu=prelu)
        self.dw4_2 = DepthWiseBlock(width(256), width(512), stride=2, prelu=prelu)

        self.dw5_1 = DepthWiseBlock(width(512), width(512), prelu=prelu)
        self.dw5_2 = DepthWiseBlock(width(512), width(512), prelu=prelu)
        self.dw5_3 = DepthWiseBlock(width(512), width(512), prelu=prelu)
        self.dw5_4 = DepthWiseBlock(width(512), width(512), prelu=prelu)
        self.dw5_5 = DepthWiseBlock(width(512), width(512), prelu=prelu)
        self.dw5_6 = DepthWiseBlock(width(512), width(1024), stride=2, prelu=prelu)

        self.dw6 = DepthWiseBlock(width(1024), width(1024), prelu=prelu)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(width(1024), num_classes)

        # Re-initialise weights: convolutions from N(0, sqrt(2 / n)) with
        # n = kernel_h * kernel_w * out_channels; batch norms to weight 1, bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                n = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        # Stem.
        x = self.relu(self.bn1(self.conv1(x)))

        # Depthwise stages, in order.
        x = self.dw2_2(self.dw2_1(x))
        x = self.dw3_2(self.dw3_1(x))
        x = self.dw4_2(self.dw4_1(x))
        x = self.dw5_1(x)
        x = self.dw5_2(x)
        x = self.dw5_3(x)
        x = self.dw5_4(x)
        x = self.dw5_5(x)
        x = self.dw5_6(x)
        x = self.dw6(x)

        # Pool to 1x1, flatten per sample, project.
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


def mobilenet(widen_factor=1.0, num_classes=1000):
    """Build a ``MobileNet`` with the given width multiplier.

    Width multipliers used by the named variants:
        1.0  -> mobilenet_1
        0.75 -> mobilenet_075
        0.5  -> mobilenet_05
        0.25 -> mobilenet_025
    """
    return MobileNet(widen_factor=widen_factor, num_classes=num_classes)


def mobilenet_2(num_classes=62, input_channel=3):
    """Build a ``MobileNet`` with ``widen_factor=2.0``."""
    return MobileNet(widen_factor=2.0, num_classes=num_classes, input_channel=input_channel)


def mobilenet_1(num_classes=62, input_channel=3):
    """Build a ``MobileNet`` with ``widen_factor=1.0``."""
    return MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel)


def mobilenet_075(num_classes=62, input_channel=3):
    """Build a ``MobileNet`` with ``widen_factor=0.75``."""
    return MobileNet(widen_factor=0.75, num_classes=num_classes, input_channel=input_channel)


def mobilenet_05(num_classes=62, input_channel=3):
    """Build a ``MobileNet`` with ``widen_factor=0.5``."""
    return MobileNet(widen_factor=0.5, num_classes=num_classes, input_channel=input_channel)


def mobilenet_025(num_classes=62, input_channel=3):
    """Build a ``MobileNet`` with ``widen_factor=0.25``."""
    return MobileNet(widen_factor=0.25, num_classes=num_classes, input_channel=input_channel)
3 changes: 3 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Face Alignment in Full Pose Range: A 3D Total Solution

The PyTorch implementation of the paper [Face Alignment in Full Pose Range: A 3D Total Solution](https://arxiv.org/abs/1804.01005).
Loading

0 comments on commit 2595400

Please sign in to comment.