-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
127 lines (101 loc) · 4.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#! /usr/bin/env python
import argparse
import os
import sys

import cv2
import dlib
import numpy as np

from face_detection import face_detection
from face_points_detection import face_points_detection
from face_swap import warp_image_2d, warp_image_3d, mask_from_points, apply_mask, correct_colours, transformation_from_points
def select_face(im, r=10):
    """Pick one detected face in *im* and return its points and a padded crop.

    Faces are found with ``face_detection``. With exactly one face it is used
    directly. With several, the user chooses one by left-clicking inside its
    box in an OpenCV window. With none, a message is printed and the process
    terminates via ``sys.exit(-1)``.

    Args:
        im: Image array (indexed ``im[row, col]``), e.g. from ``cv2.imread``.
        r: Padding in pixels added on every side of the points' bounding box
            before cropping; the padded box is clamped to the image borders.

    Returns:
        A tuple ``(points, (x, y, w, h), face)``:
        ``points`` are the (x, y) pairs from ``face_points_detection`` shifted
        so they are relative to the crop's top-left corner; ``(x, y, w, h)``
        is the crop rectangle in ``im`` (column, row, width, height); ``face``
        is ``im[y:y+h, x:x+w]`` (a NumPy view into ``im``, not a copy).
    """
    faces = face_detection(im)
    if len(faces) == 0:
        print('Detect 0 Face !!!')
        # sys.exit instead of the builtin exit(): exit() is injected by the
        # site module and is not guaranteed to exist (e.g. `python -S`).
        sys.exit(-1)

    bbox = faces[0] if len(faces) == 1 else _choose_face(im, faces)

    points = np.asarray(face_points_detection(im, bbox))

    # NumPy shape order is (rows, cols), i.e. (height, width).
    im_h, im_w = im.shape[:2]
    left, top = np.min(points, 0)
    right, bottom = np.max(points, 0)

    # Pad the points' bounding box by r pixels and clamp it to the image.
    x, y = max(0, left - r), max(0, top - r)
    w, h = min(right + r, im_w) - x, min(bottom + r, im_h) - y

    return points - np.asarray([[x, y]]), (x, y, w, h), im[y:y+h, x:x+w]


def _choose_face(im, faces):
    """Show *faces* boxed on a copy of *im* and return the one the user clicks."""
    window = 'Click the Face:'
    chosen = []

    def click_on_face(event, x, y, flags, params):
        # Record the first face whose box strictly contains a left-click.
        if event != cv2.EVENT_LBUTTONDOWN:
            return
        for face in faces:
            if face.left() < x < face.right() and face.top() < y < face.bottom():
                chosen.append(face)
                break

    im_copy = im.copy()
    for face in faces:
        # draw the face bounding box
        cv2.rectangle(im_copy, (face.left(), face.top()), (face.right(), face.bottom()), (0, 0, 255), 1)
    cv2.imshow(window, im_copy)
    cv2.setMouseCallback(window, click_on_face)
    # Pump the GUI event loop until the callback records a click. The loop
    # only ends once a left-click lands inside one of the face boxes.
    while not chosen:
        cv2.waitKey(1)
    cv2.destroyAllWindows()
    return chosen[0]
if __name__ == '__main__':
    # CLI entry point: select a face in --src and one in --dst, warp the
    # source face onto the destination face, Poisson-blend it in, and write
    # the full destination image with the swapped face to --out.
    parser = argparse.ArgumentParser(description='FaceSwapApp')
    parser.add_argument('--src', required=True, help='Path for source image')
    parser.add_argument('--dst', required=True, help='Path for target image')
    parser.add_argument('--out', required=True, help='Path for storing output images')
    parser.add_argument('--warp_2d', default=False, action='store_true', help='2d or 3d warp')
    parser.add_argument('--correct_color', default=False, action='store_true', help='Correct color')
    args = parser.parse_args()

    # Read images. cv2.imread returns None instead of raising when a file is
    # missing or unreadable, which would otherwise fail later with an opaque error.
    src_img = cv2.imread(args.src)
    dst_img = cv2.imread(args.dst)
    for path, img in ((args.src, src_img), (args.dst, dst_img)):
        if img is None:
            print('Cannot read image: {}'.format(path))
            sys.exit(-1)

    # Select src face
    src_points, src_shape, src_face = select_face(src_img)
    # Select dst face
    dst_points, dst_shape, dst_face = select_face(dst_img)

    # NumPy shape order is (rows, cols) == (height, width); the sizes passed
    # to the face-swap helpers below are given in this same order.
    dst_h, dst_w = dst_face.shape[:2]

    ### Warp Image
    if not args.warp_2d:
        ## 3d warp (uses only the first 48 points of each face)
        warped_src_face = warp_image_3d(src_face, src_points[:48], dst_points[:48], (dst_h, dst_w))
    else:
        ## 2d warp
        src_mask = mask_from_points(src_face.shape[:2], src_points)
        src_face = apply_mask(src_face, src_mask)
        # Correct Color for 2d warp (done before warping on this path)
        if args.correct_color:
            warped_dst_img = warp_image_3d(dst_face, dst_points[:48], src_points[:48], src_face.shape[:2])
            src_face = correct_colours(warped_dst_img, src_face, src_points)
        # Warp
        warped_src_face = warp_image_2d(src_face, transformation_from_points(dst_points, src_points),
                                        (dst_h, dst_w, 3))

    ## Mask for blending: the dst points' mask, restricted to pixels where the
    ## warped source face is non-black (channel mean > 0).
    mask = mask_from_points((dst_h, dst_w), dst_points)
    mask_src = np.mean(warped_src_face, axis=2) > 0
    mask = np.asarray(mask * mask_src, dtype=np.uint8)

    ## Correct color (3d path only; the 2d path corrected before warping)
    if not args.warp_2d and args.correct_color:
        warped_src_face = apply_mask(warped_src_face, mask)
        dst_face_masked = apply_mask(dst_face, mask)
        warped_src_face = correct_colours(dst_face_masked, warped_src_face, dst_points)

    ## Shrink the mask
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)

    # A small face region can vanish entirely under the 10x10 erosion;
    # seamlessClone would then fail with an unhelpful OpenCV error.
    if not np.any(mask):
        print('Blending mask is empty; the face region is too small to blend.')
        sys.exit(-1)

    ## Poisson Blending, centred on the mask's bounding box
    r = cv2.boundingRect(mask)
    center = (r[0] + r[2] // 2, r[1] + r[3] // 2)
    output = cv2.seamlessClone(warped_src_face, dst_face, mask, center, cv2.NORMAL_CLONE)

    # Paste the blended crop back into a copy of the full destination image.
    x, y, w, h = dst_shape
    dst_img_cp = dst_img.copy()
    dst_img_cp[y:y+h, x:x+w] = output
    output = dst_img_cp

    # os.path.dirname() is '' for a bare filename such as 'out.jpg', and
    # os.makedirs('') raises, so only create a directory when one is given.
    dir_path = os.path.dirname(args.out)
    if dir_path and not os.path.isdir(dir_path):
        os.makedirs(dir_path)

    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(args.out, output):
        print('Failed to write output image: {}'.format(args.out))
        sys.exit(-1)

    ##For debug
    cv2.imshow("From", dst_img)
    cv2.imshow("To", output)
    cv2.waitKey(0)

    cv2.destroyAllWindows()