-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpoly_selection.py
72 lines (50 loc) · 2.17 KB
/
poly_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.widgets import PolygonSelector
from matplotlib.path import Path
from matplotlib import colors
#----------------polygonal selection----------------
class PolygonSelect:
    """
    Interactive polygon picker on top of an RGB image.

    Creating an instance opens a figure showing the image and attaches a
    matplotlib PolygonSelector to its axes. Whatever vertices the selector
    hands to the ``onselect`` callback are kept in the ``verts`` attribute,
    which starts out as an empty list.
    """

    def __init__(self, RGB):
        """Show ``RGB`` in a new figure and attach a PolygonSelector to it."""
        # Nothing selected yet; onselect replaces this value.
        self.verts = []

        _, axes = plt.subplots(figsize=[10, 10])
        axes.imshow(RGB)

        # Usage hints for the interactive selector, printed in order.
        for hint in (
            "Click on the figure to create a polygon.",
            "Press the 'esc' key to start a new polygon.",
            "Try holding the 'shift' key to move all of the vertices.",
            "Try holding the 'ctrl' key to move a single vertex.",
        ):
            print(hint)

        # The selector is stored on the instance as ``poly``.
        line_style = {"color": "r", "linewidth": 2}
        self.poly = PolygonSelector(axes, self.onselect, props=line_style)

    def onselect(self, verts):
        """Callback required by PolygonSelector; stores the vertices it receives."""
        self.verts = verts
#---------------
#references
#https://matplotlib.org/stable/gallery/widgets/polygon_selector_demo.html
#visualizing polygonal selection
def CreateMask(polygon, RGB):
    '''
    Function to create the polygonal mask and show it over the image.

    Input:
        polygon: object exposing the polygon's vertexes through a ``verts``
            attribute (e.g. a PolygonSelect instance). Vertexes are used as
            (x, y) pairs, i.e. (column, row) in image coordinates.
        RGB: image array. Its first two dimensions (rows, columns) size the
            mask, and the image is displayed underneath the mask overlay.
    Output:
        mask: boolean array of shape ``RGB.shape[:2]``, True where the
            integer pixel coordinate (column, row) is inside the polygon.
    Raises:
        ValueError: if ``polygon.verts`` is empty, i.e. no polygon is available.
    '''
    verts = polygon.verts
    # verts may still be empty (PolygonSelect starts with an empty list).
    # Fail early with an explicit message instead of letting Path reject it.
    # len() is used so that both lists and arrays are handled.
    if len(verts) == 0:
        raise ValueError(
            "polygon has no vertices: draw a polygon before calling CreateMask"
        )

    # Create a path based on the polygon's vertexes
    path = Path(verts)

    # x and y coordinates for every point in the image, in row-major order
    n_rows, n_cols = RGB.shape[:2]
    y, x = np.mgrid[:n_rows, :n_cols]
    points = np.vstack((x.ravel(), y.ravel())).T  # points must be (N, 2) array

    # creating the mask: test every pixel, then fold back to the image grid
    mask = path.contains_points(points)
    mask = mask.reshape(RGB.shape[:2])

    # plot: image first, then the mask drawn semi-transparently on top
    # ('none' for False pixels, 'red' for True pixels)
    cmap = colors.ListedColormap(['none', 'red'])
    _, ax = plt.subplots(figsize=[10, 10])
    ax.imshow(RGB)
    ax.imshow(mask, alpha=0.5, cmap=cmap)
    return mask