-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_autofocus.py
52 lines (42 loc) · 1.74 KB
/
model_autofocus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import cv2
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Linear, Dropout, Hardswish
from torchvision.transforms import transforms
# Module-level MobileNetV3-Small backbone, built at import time with
# torchvision's default constructor arguments (no weights argument is passed,
# so presumably randomly initialised — confirm for the installed torchvision
# version). The __main__ block replaces its classifier head with a
# single-output regression head.
model = torchvision.models.mobilenet_v3_small()
def rename_attribute(obj, old_name, new_name):
    """Rename a registered child module of *obj* in place, keeping its position.

    ``nn.Module`` keeps its children in the ordered ``_modules`` mapping, and
    containers such as ``nn.Sequential`` execute their children in that order.
    Popping the entry and re-inserting it would move the child to the end, so
    the mapping is rebuilt with the new key in the old slot instead.

    Args:
        obj: Object exposing a ``_modules`` mapping (e.g. a ``torch.nn.Module``).
        old_name: Current key of the child module.
        new_name: Key to register the child under instead.

    Raises:
        KeyError: If *old_name* is not a registered child.
        ValueError: If *new_name* is already used by a different child.
    """
    modules = obj._modules
    if old_name not in modules:
        raise KeyError(old_name)
    if new_name == old_name:
        # Nothing to do; in particular, do not move the entry.
        return
    if new_name in modules:
        # Refuse to silently clobber an existing child module.
        raise ValueError(f"a child module named {new_name!r} already exists")
    renamed = [
        (new_name if key == old_name else key, child)
        for key, child in modules.items()
    ]
    # Mutate the existing mapping rather than rebinding obj._modules, so the
    # mapping object (and its type, e.g. OrderedDict) is preserved.
    modules.clear()
    modules.update(renamed)
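# NOTE(review): draft that does not work as written.
# torchvision.models.mobilenet_v3_small is a factory function, not a class, so
# it cannot be subclassed, and __init__ never calls super().__init__().
# The __main__ block below gets the same effect by replacing model.classifier
# on an instance.
# class Mobilnetv3Regressor(torchvision.models.mobilenet_v3_small):
#     def __init__(self):
#         self.classifier = nn.Sequential(Linear(in_features=576, out_features=1024),
#                                         Hardswish(),
#                                         Dropout(p=0.2),
#                                         Linear(in_features=1024, out_features=1))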
if __name__ == "__main__":
    import sys

    # Image to run through the network: first CLI argument if given,
    # otherwise the original hard-coded sample image.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"C:\Users\tristan_cotte\PycharmProjects\prior_controller\autofocus\sly_project\ds0\img\1976_1976_10.jpg"
    )

    # Force 3 channels: a grayscale or RGBA file would otherwise yield a
    # 1- or 4-channel tensor that the network's first convolution rejects.
    # The context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Convert the PIL image to a (C, H, W) tensor, then add a batch dimension.
    transform = transforms.ToTensor()
    tensor = transform(image)
    tensor = torch.unsqueeze(tensor, dim=0)
    print(tensor.shape)

    # Replace the default classification head with a single-output regression
    # head, reusing the backbone's feature width (576 for mobilenet_v3_small).
    in_features = model.classifier[0].in_features
    model.classifier = nn.Sequential(Linear(in_features=in_features, out_features=1024),
                                     Hardswish(),
                                     Dropout(p=0.2),
                                     Linear(in_features=1024, out_features=1))
    # rename_attribute(model, 'classifier', 'regressor')
    print(model)

    # eval() disables Dropout and puts BatchNorm layers in inference mode, so
    # the prediction is deterministic and running statistics are not updated.
    model.eval()
    with torch.no_grad():
        print(model(tensor))