-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodels.py
100 lines (79 loc) · 2.39 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from .blocks import *
# Public API of this module: the configurable ResNet class plus one
# convenience factory per supported depth.
__all__ = [
    "ResNet",
    "ResNet_18",
    "ResNet_34",
    "ResNet_50",
    "ResNet_101",
    "ResNet_152",
]
# Number of residual blocks in each of the four stages, keyed by the network
# depth as a string (the same string ResNet receives as ``model_type``).
# NOTE(review): '34' and '50' share a layout; they presumably differ in the
# residual block type that ResidualBlock picks from the depth — confirm in blocks.
RESNET_TYPE = {
    "18": [2, 2, 2, 2],
    "34": [3, 4, 6, 3],
    "50": [3, 4, 6, 3],
    "101": [3, 4, 23, 3],
    "152": [3, 8, 36, 3],
}
class ResNet(nn.Module):
    """ResNet image classifier assembled from the building blocks in ``.blocks``.

    Layout: a stem (strided 7x7 ``ConvBlock`` + strided 3x3 max-pool), four
    residual stages whose block counts come from ``RESNET_TYPE[model_type]``,
    global average pooling, and a ``Classifier`` head.

    Args:
        model_type: Network depth as a string key of ``RESNET_TYPE``
            (``'18'``, ``'34'``, ``'50'``, ``'101'`` or ``'152'``).
        image_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        dropout_rate: Forwarded to ``Classifier`` as ``dropout_rate``.

    Raises:
        ValueError: If ``model_type`` is not an integer string.
        KeyError: If ``model_type`` is not a key of ``RESNET_TYPE``.
    """

    def __init__(
        self,
        model_type: str,
        image_channels: int,
        num_classes: int,
        dropout_rate: float = 0.5
    ):
        super().__init__()
        depth = int(model_type)
        width = 64

        # Stem: strided 7x7 convolution followed by a strided 3x3 max-pool.
        modules = [
            ConvBlock(image_channels, width, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]

        # Residual stages; the base width doubles after every stage.
        for stage_index, num_blocks in enumerate(RESNET_TYPE[model_type]):
            modules.append(ResidualBlock(depth, stage_index, num_blocks, width))
            width *= 2

        modules.append(nn.AdaptiveAvgPool2d(1))

        # The loop doubles `width` once past the final stage, so the last
        # stage's base width is width // 2. Networks shallower than 50 layers
        # feed that width to the head; deeper ones feed 4x it (presumably the
        # bottleneck expansion inside ResidualBlock — confirm in blocks).
        last_stage_width = width // 2
        expansion = 1 if depth < 50 else 4
        head_features = last_stage_width * expansion

        # Attribute names and registration order are kept stable so that
        # previously saved state_dicts still load.
        self.feature_extractor = nn.Sequential(*modules)
        self.classifier = Classifier(
            in_features=head_features,
            out_features=num_classes,
            dropout_rate=dropout_rate,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classifier logits for the input batch ``x``.

        ``x`` is presumably ``(batch, image_channels, H, W)`` — confirm with callers.
        """
        pooled = self.feature_extractor(x)
        # Drop the pooled spatial dims so the head sees one feature vector per sample.
        flat = torch.flatten(pooled, start_dim=1)
        return self.classifier(flat)
def ResNet_18(
    image_channels: int,
    num_classes: int,
    dropout_rate: float = 0.5
) -> ResNet:
    """Build an 18-layer ResNet.

    Args:
        image_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        dropout_rate: Dropout rate handed to the classifier head.

    Returns:
        A ``ResNet`` built with the ``'18'`` stage layout.
    """
    return ResNet(
        model_type='18',
        image_channels=image_channels,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )
def ResNet_34(
    image_channels: int,
    num_classes: int,
    dropout_rate: float = 0.5
) -> ResNet:
    """Build a 34-layer ResNet.

    Args:
        image_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        dropout_rate: Dropout rate handed to the classifier head.

    Returns:
        A ``ResNet`` built with the ``'34'`` stage layout.
    """
    return ResNet(
        model_type='34',
        image_channels=image_channels,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )
def ResNet_50(
    image_channels: int,
    num_classes: int,
    dropout_rate: float = 0.5
) -> ResNet:
    """Build a 50-layer ResNet.

    Args:
        image_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        dropout_rate: Dropout rate handed to the classifier head.

    Returns:
        A ``ResNet`` built with the ``'50'`` stage layout.
    """
    return ResNet(
        model_type='50',
        image_channels=image_channels,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )
def ResNet_101(
    image_channels: int,
    num_classes: int,
    dropout_rate: float = 0.5
) -> ResNet:
    """Build a 101-layer ResNet.

    Args:
        image_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        dropout_rate: Dropout rate handed to the classifier head.

    Returns:
        A ``ResNet`` built with the ``'101'`` stage layout.
    """
    return ResNet(
        model_type='101',
        image_channels=image_channels,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )
def ResNet_152(
    image_channels: int,
    num_classes: int,
    dropout_rate: float = 0.5
) -> ResNet:
    """Build a 152-layer ResNet.

    Args:
        image_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        dropout_rate: Dropout rate handed to the classifier head.

    Returns:
        A ``ResNet`` built with the ``'152'`` stage layout.
    """
    return ResNet(
        model_type='152',
        image_channels=image_channels,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    )