-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simple formatting with Black, CPU support for inference and forgotten main function in training script #3
base: master
Are you sure you want to change the base?
Changes from 4 commits
1999b55
39e6fb1
db3b54e
0bf0901
941b8dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
__pycache__ | ||
*/__pycache__ | ||
**/__pycache__ | ||
saved_models/ | ||
.vscode | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,116 +1,129 @@ | ||
import os | ||
import glob | ||
import time | ||
|
||
import numpy as np | ||
from PIL import Image | ||
from skimage import io, transform | ||
|
||
import torch | ||
import torchvision | ||
from torch.autograd import Variable | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.autograd import Variable | ||
from torch.utils.data import Dataset, DataLoader | ||
from torchvision import transforms#, utils | ||
# import torch.optim as optim | ||
|
||
import numpy as np | ||
from PIL import Image | ||
import glob | ||
# import torch.optim as optim | ||
import torchvision | ||
from torchvision import transforms # , utils | ||
Comment on lines
+2
to
+17
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. prettier import statements |
||
|
||
from data_loader import RescaleT | ||
from data_loader import ToTensor | ||
from data_loader import ToTensorLab | ||
from data_loader import SalObjDataset | ||
|
||
from model import U2NET # full size version 173.6 MB | ||
from model import U2NETP # small version u2net 4.7 MB | ||
from model import U2NET # full size version 173.6 MB | ||
from model import U2NETP # small version u2net 4.7 MB | ||
|
||
|
||
# normalize the predicted SOD probability map | ||
def normPRED(d): | ||
ma = torch.max(d) | ||
mi = torch.min(d) | ||
|
||
dn = (d-mi)/(ma-mi) | ||
dn = (d - mi) / (ma - mi) | ||
|
||
return dn | ||
|
||
def save_output(image_name,pred,d_dir): | ||
|
||
def save_output(image_name, pred, d_dir): | ||
|
||
predict = pred | ||
predict = predict.squeeze() | ||
predict_np = predict.cpu().data.numpy() | ||
|
||
im = Image.fromarray(predict_np*255).convert('RGB') | ||
im = Image.fromarray(predict_np * 255).convert("RGB") | ||
img_name = image_name.split("/")[-1] | ||
image = io.imread(image_name) | ||
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR) | ||
imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR) | ||
|
||
pb_np = np.array(imo) | ||
|
||
aaa = img_name.split(".") | ||
bbb = aaa[0:-1] | ||
imidx = bbb[0] | ||
for i in range(1,len(bbb)): | ||
for i in range(1, len(bbb)): | ||
imidx = imidx + "." + bbb[i] | ||
|
||
imo.save(d_dir+imidx+'.png') | ||
imo.save(d_dir + imidx + ".png") | ||
|
||
|
||
def main(): | ||
|
||
# --------- 1. get image path and name --------- | ||
model_name='u2net'#u2netp | ||
model_name = "u2net" # u2netp | ||
|
||
image_dir = "./test_data/test_images/" | ||
prediction_dir = "./test_data/" + model_name + "_results/" | ||
model_dir = "./saved_models/" + model_name + "/" + model_name + ".pth" | ||
|
||
image_dir = './test_data/test_images/' | ||
prediction_dir = './test_data/' + model_name + '_results/' | ||
model_dir = './saved_models/'+ model_name + '/' + model_name + '.pth' | ||
|
||
img_name_list = glob.glob(image_dir + '*') | ||
img_name_list = glob.glob(image_dir + "*") | ||
print(img_name_list) | ||
|
||
# --------- 2. dataloader --------- | ||
#1. dataloader | ||
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list, | ||
lbl_name_list = [], | ||
transform=transforms.Compose([RescaleT(320), | ||
ToTensorLab(flag=0)]) | ||
) | ||
test_salobj_dataloader = DataLoader(test_salobj_dataset, | ||
batch_size=1, | ||
shuffle=False, | ||
num_workers=1) | ||
# 1. dataloader | ||
test_salobj_dataset = SalObjDataset( | ||
img_name_list=img_name_list, | ||
lbl_name_list=[], | ||
transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]), | ||
) | ||
test_salobj_dataloader = DataLoader( | ||
test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1 | ||
) | ||
|
||
# --------- 3. model define --------- | ||
if(model_name=='u2net'): | ||
if model_name == "u2net": | ||
print("...load U2NET---173.6 MB") | ||
net = U2NET(3,1) | ||
elif(model_name=='u2netp'): | ||
net = U2NET(3, 1) | ||
elif model_name == "u2netp": | ||
print("...load U2NEP---4.7 MB") | ||
net = U2NETP(3,1) | ||
net.load_state_dict(torch.load(model_dir)) | ||
net = U2NETP(3, 1) | ||
|
||
if torch.cuda.is_available(): | ||
net.load_state_dict(torch.load(model_dir)) | ||
net.cuda() | ||
else: | ||
net.load_state_dict(torch.load(model_dir, map_location=torch.device("cpu"))) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you try to load_state_dict on CPU without mapping location to CPU, you will have a RuntimeError There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have tried this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @MatiasConTilde this should also work on torch>=0.4 with map_location="cpu" |
||
net.eval() | ||
|
||
# --------- 4. inference for each image --------- | ||
for i_test, data_test in enumerate(test_salobj_dataloader): | ||
|
||
print("inferencing:",img_name_list[i_test].split("/")[-1]) | ||
start = time.time() | ||
|
||
inputs_test = data_test['image'] | ||
inputs_test = data_test["image"] | ||
inputs_test = inputs_test.type(torch.FloatTensor) | ||
|
||
if torch.cuda.is_available(): | ||
inputs_test = Variable(inputs_test.cuda()) | ||
else: | ||
inputs_test = Variable(inputs_test) | ||
|
||
d1,d2,d3,d4,d5,d6,d7= net(inputs_test) | ||
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test) | ||
|
||
print( | ||
f"Predicted {os.path.basename(img_name_list[i_test])} in {time.time() - start:.2f}s" | ||
) | ||
Comment on lines
+102
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
# normalization | ||
pred = d1[:,0,:,:] | ||
pred = d1[:, 0, :, :] | ||
pred = normPRED(pred) | ||
|
||
# save results to test_results folder | ||
save_output(img_name_list[i_test],pred,prediction_dir) | ||
save_output(img_name_list[i_test], pred, prediction_dir) | ||
|
||
del d1, d2, d3, d4, d5, d6, d7 | ||
|
||
del d1,d2,d3,d4,d5,d6,d7 | ||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simple .gitignore