Refine video_demo
cleardusk committed Jun 17, 2019
1 parent b28ac6c commit 6613246
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions video_demo.py
@@ -6,44 +6,37 @@
 import numpy as np
 import cv2
 import dlib
-from utils.ddfa import ToTensorGjz, NormalizeGjz, str2bool
+from utils.ddfa import ToTensorGjz, NormalizeGjz
 import scipy.io as sio
 from utils.inference import (
-    get_suffix,
     parse_roi_box_from_landmark,
     crop_img,
     predict_68pts,
-    dump_to_ply,
-    dump_vertex,
-    draw_landmarks,
     predict_dense,
-    parse_roi_box_from_bbox,
-    get_colors,
-    write_obj_with_colors,
 )
-from utils.cv_plot import plot_pose_box
-from utils.estimate_pose import parse_pose
 from utils.cv_plot import plot_kpt
 from utils.render import get_depths_image, cget_depths_image, cpncc
 from utils.paf import gen_img_paf
 import argparse
 import torch.backends.cudnn as cudnn
 
 STD_SIZE = 120
 
-def main(args):
+
+def main(args):
     # 0. open video
-    vc = cv2.VideoCapture(str(args.video) if len(args.video) == 1 else args.video)
+    # vc = cv2.VideoCapture(str(args.video) if len(args.video) == 1 else args.video)
+    vc = cv2.VideoCapture(args.video if int(args.video) != 0 else 0)
 
     # 1. load pre-tained model
-    checkpoint_fp = "models/phase1_wpdc_vdc.pth.tar"
-    arch = "mobilenet_1"
+    checkpoint_fp = 'models/phase1_wpdc_vdc.pth.tar'
+    arch = 'mobilenet_1'
 
     tri = sio.loadmat('visualize/tri.mat')['tri']
     transform = transforms.Compose([ToTensorGjz(), NormalizeGjz(mean=127.5, std=128)])
 
     checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)[
-        "state_dict"
+        'state_dict'
     ]
     model = getattr(mobilenet_v1, arch)(
         num_classes=62
@@ -52,22 +45,22 @@ def main(args):
     model_dict = model.state_dict()
     # because the model is trained by multiple gpus, prefix module should be removed
     for k in checkpoint.keys():
-        model_dict[k.replace("module.", "")] = checkpoint[k]
+        model_dict[k.replace('module.', '')] = checkpoint[k]
     model.load_state_dict(model_dict)
-    if args.mode == "gpu":
+    if args.mode == 'gpu':
         cudnn.benchmark = True
         model = model.cuda()
     model.eval()
 
     # 2. load dlib model for face detection and landmark used for face cropping
-    dlib_landmark_model = "models/shape_predictor_68_face_landmarks.dat"
+    dlib_landmark_model = 'models/shape_predictor_68_face_landmarks.dat'
     face_regressor = dlib.shape_predictor(dlib_landmark_model)
     face_detector = dlib.get_frontal_face_detector()
 
     # 3. forward
     success, frame = vc.read()
     last_frame_pts = []
 
     while success:
         if len(last_frame_pts) == 0:
             rects = face_detector(frame, 1)
@@ -85,7 +78,7 @@
         )
         input = transform(img).unsqueeze(0)
         with torch.no_grad():
-            if args.mode == "gpu":
+            if args.mode == 'gpu':
                 input = input.cuda()
             param = model(input)
             param = param.squeeze().cpu().numpy().flatten().astype(np.float32)
@@ -96,21 +89,21 @@
 
         pncc = cpncc(frame, vertices_lst, tri - 1) / 255.0
         frame = frame / 255.0 * (1.0 - pncc)
-        cv2.imshow("3ddfa", frame)
+        cv2.imshow('3ddfa', frame)
         cv2.waitKey(1)
         success, frame = vc.read()
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="3DDFA inference pipeline")
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='3DDFA inference pipeline')
     parser.add_argument(
-        "-v",
-        "--video",
-        default="0",
+        '-v',
+        '--video',
+        default='0',
         type=str,
-        help="video file path or opencv cam index",
+        help='video file path or opencv cam index',
     )
-    parser.add_argument("-m", "--mode", default="cpu", type=str, help="gpu or cpu mode")
+    parser.add_argument('-m', '--mode', default='cpu', type=str, help='gpu or cpu mode')
 
     args = parser.parse_args()
     main(args)
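A quick usage sketch (an editorial addition, not part of the commit): assuming the checkpoint models/phase1_wpdc_vdc.pth.tar and the dlib model models/shape_predictor_68_face_landmarks.dat are in place as the script expects, the refined demo runs with the argparse defaults shown above:

    python video_demo.py           # webcam index 0, CPU inference
    python video_demo.py -m gpu    # same source, CUDA inference

Note that the new capture line applies int(args.video), which raises ValueError for a non-numeric path and passes non-zero indices on as strings, so despite the help text it effectively handles only the default camera index '0'. A hypothetical, more permissive opener (illustration only, not in this commit) could be:

    # Hypothetical helper: numeric strings become camera indices,
    # anything else is treated as a video file path.
    def open_capture(src):
        return cv2.VideoCapture(int(src) if src.isdigit() else src)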
