-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_segmentation.py
219 lines (164 loc) · 8.03 KB
/
image_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import os
import threading
import time
from idlelib.tooltip import Hovertip
from tkinter import *
from tkinter import filedialog

import numpy as np
import torch
from diffusers.utils import load_image, make_image_grid
# NOTE: must stay below `from tkinter import *`, which also exports a name `Image`.
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import pipeline

from controls import create_toolbar_button, create_number_control
# Debug switch: when True, segment() and remove() run their work synchronously
# on the calling thread instead of starting a worker thread, and
# segment_thread() prints the prompt and model name it read from the UI.
DEBUG = False
def show_masks_on_image(raw_image_path, masks, transparency=1.0):
    """Overlay every mask on the image at ``raw_image_path`` in a random color.

    Args:
        raw_image_path: Path (or file object) accepted by ``PIL.Image.open``.
        masks: Iterable of array-likes whose last two dimensions are
            (height, width) and which hold exactly height * width elements.
            Presumably boolean / 0-1 masks where truthy entries mark the
            pixels to tint -- TODO confirm against the mask producer.
        transparency: Alpha applied to the overlay. Despite the name this is
            an opacity: 0 leaves the image untouched, 1 paints masked pixels
            in solid color. Values outside [0, 1] are clamped.

    Returns:
        A new RGBA ``PIL.Image.Image`` with all masks composited onto it.
    """
    # Clamp so that 255 * alpha cannot overflow and wrap in the uint8 cast.
    alpha = float(np.clip(transparency, 0.0, 1.0))

    def get_mask_image(mask):
        """Turn one mask into an RGBA image tinted with a random color."""
        # Random RGB in [0, 1) plus the shared alpha value.
        color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
        mask = np.asarray(mask)
        # Height and width are the last two dimensions.
        h, w = mask.shape[-2:]
        # Broadcast (h, w, 1) * (1, 1, 4): masked pixels get the color and
        # alpha, everything else stays fully transparent black.
        rgba = (255 * mask).reshape(h, w, 1) * color.reshape(1, 1, -1)
        return Image.fromarray(rgba.astype('uint8'), mode='RGBA')

    # convert() returns an independent copy, so the file can be closed at once.
    with Image.open(raw_image_path) as source:
        raw_image = source.convert("RGBA")

    # Paste each overlay using its own alpha channel as the paste mask.
    for mask in masks:
        overlay = get_mask_image(mask)
        raw_image.paste(overlay, (0, 0), overlay)
    return raw_image
class image_segmentation_ui:
    """Tkinter panel that segments and captions the latest image in a history.

    The panel is made of a toolbar (model selector, overlay opacity control,
    "Segment" and "Caption Image" buttons), a canvas showing the last image
    path stored in ``history``, and a read-only text box that collects the
    generated captions.
    """

    def __init__(self, parent, history, width=512, height=512):
        """Create all widgets inside ``parent``.

        Args:
            parent: Tk container that the toolbar and both frames are packed
                into.
            history: Mutable list of image file paths. The last entry is the
                image that is displayed and processed; results are appended.
            width: Initial canvas/frame width in pixels.
            height: Initial canvas/frame height in pixels.
        """
        # Efficient attention is not native in old PyTorch versions and is needed to reduce GPU usage
        self.use_efficient_attention = int(torch.__version__.split('.')[0]) < 2
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Size of image to work with
        self.width = width
        self.height = height

        # Used to store generated images
        self.history = history

        # Get frames needed for layout
        toolbar, left_frame, right_frame = self.create_layout(parent)

        # Populate controls
        self.initialize_toolbar(toolbar)
        self.initialize_prompts(right_frame)
        self.initialize_canvas(left_frame)
        self.update_controls()

    def create_layout(self, parent):
        """Create and pack the three top-level frames.

        Returns:
            Tuple ``(toolbar, left_frame, right_frame)``.
        """
        # Toolbar spans the top of the parent
        toolbar = Frame(parent, width=2*self.width, height=20, bg='light grey')
        toolbar.pack(side=TOP, fill=X, expand=False)

        # Left frame hosts the canvas
        left_frame = Frame(parent, width=self.width, height=self.height, bg='grey')
        left_frame.pack(side=LEFT, fill=BOTH, expand=False)

        # Right frame hosts the prompt/caption text box
        right_frame = Frame(parent, width=self.width, height=self.height, bg='grey')
        right_frame.pack(side=RIGHT, fill=BOTH, expand=True)

        return toolbar, left_frame, right_frame

    def initialize_canvas(self, parent):
        """Create the black canvas on which the current image is drawn."""
        self.canvas = Canvas(parent, bg="black", width=self.width, height=self.height)
        self.canvas.pack(fill=BOTH, expand=False)

    def initialize_prompts(self, parent):
        """Create the labelled text box that receives generated captions."""
        prompt = ""
        Label(parent, text="Prompt:", anchor=W).pack(side=TOP, fill=X, expand=False)
        self.prompt = Text(parent, height=1, wrap=WORD)
        self.prompt.insert(END, prompt)
        self.prompt.pack(side=TOP, fill=BOTH, expand=True)

    def initialize_toolbar(self, toolbar):
        """Populate the toolbar with the model selector, opacity and buttons."""
        # Create combo box for selecting a segmentation model
        checkpoint_frame = Frame(toolbar, bg='grey')
        checkpoint_options = ["Segment Anything"]
        self.checkpoint = StringVar(checkpoint_frame, checkpoint_options[0])
        Hovertip(checkpoint_frame, 'Select the model to use')
        Label(checkpoint_frame, text="Model", anchor=W).pack(side=LEFT, fill=Y, expand=False)
        checkpoint_menu = OptionMenu(checkpoint_frame, self.checkpoint, *checkpoint_options)
        checkpoint_menu.config(width=20)
        checkpoint_menu.pack(side=LEFT, fill=X, expand=True)
        checkpoint_frame.pack(side=LEFT, fill=X, expand=False)

        # Create textbox for entering the transparency of the segmentation.
        # The value is passed to show_masks_on_image as the overlay alpha.
        self.transparency_entry = create_number_control(
            toolbar, 0.6, "Transparency",
            'Enter a value from 0 to 1. Higher values make the segmentation overlay more opaque.',
            increment=.05, type=float, min=0, max=1)

        # Create a button to segment the image
        self.segment_button = create_toolbar_button(toolbar, 'Segment', self.segment, 'Segment the image')

        # Create a button to caption the image. Stored under its own attribute
        # so that it no longer overwrites the reference to the segment button.
        # NOTE(review): caption() runs on the thread that handles the click.
        self.caption_button = create_toolbar_button(toolbar, 'Caption Image', self.caption, 'Caption the image')

    def refresh_ui(self):
        """Redraw the canvas from the last history entry and reset controls."""
        # Drop previously drawn items so they do not pile up on every refresh.
        self.canvas.delete("all")
        if self.history:
            # Keep a reference on self; Tk does not keep the image alive itself.
            self.canvas_bg = PhotoImage(file=self.history[-1])
            self.width, self.height = self.canvas_bg.width(), self.canvas_bg.height()
            self.canvas.config(width=self.width, height=self.height)
            self.canvas.create_image(0, 0, image=self.canvas_bg, anchor=NW)
        self.update_controls()

    def update_controls(self):
        """Put the prompt box into its read-only, greyed-out state."""
        self.prompt["state"] = DISABLED
        self.prompt['bg'] = '#D3D3D3'

    def update_canvas_image(self, image):
        """Save ``image`` as a new history entry and display it.

        Args:
            image: Object with a ``save(path)`` method (a PIL image).
        """
        path = 'history/{}.png'.format(time.time())
        # Make sure the target directory exists before writing into it.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Save first: a failed save must not leave a dangling path in history.
        image.save(path)
        self.history.append(path)
        self.refresh_ui()

    def segment(self):
        """Run segmentation, on a worker thread unless DEBUG is set."""
        if DEBUG:
            self.segment_thread()
        else:
            threading.Thread(target=self.segment_thread).start()

    def segment_thread(self):
        """Segment the latest history image and show the colored overlay.

        Does nothing when the history is empty, since there is no image to
        overlay the masks on.
        """
        if not self.history:
            return

        # Get all necessary arguments from UI
        prompt = self.prompt.get('1.0', 'end-1 chars')
        model_name = self.checkpoint.get()
        transparency = float(self.transparency_entry.get())

        # Resolve the path once so the same image is segmented and overlaid
        # even if the history changes while the model is running.
        image_path = self.history[-1]
        init_image = load_image(image_path)

        # Use the device detected in __init__ instead of assuming GPU 0.
        generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device)
        outputs = generator(init_image, points_per_batch=64)

        image = show_masks_on_image(image_path, outputs["masks"], transparency)
        self.update_canvas_image(image)

        # Use to validate inputs and outputs
        if DEBUG:
            print(prompt)
            print(model_name)

    def remove(self):
        """Run caption(), on a worker thread unless DEBUG is set.

        NOTE(review): despite its name this method launches captioning, not
        any kind of removal; kept as is for compatibility with callers.
        """
        if DEBUG:
            self.caption()
        else:
            threading.Thread(target=self.caption).start()

    def caption(self):
        """Caption the latest history image and append the text to the prompt.

        Does nothing when the history is empty.
        """
        if not self.history:
            return

        captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
        result = captioner(self.history[-1])
        caption_text = '\n' + result[0]['generated_text']

        # The text box is read-only; enable it just long enough to insert.
        self.prompt["state"] = NORMAL
        self.prompt.insert(END, caption_text)
        self.prompt["state"] = DISABLED