# models.py — VGG image-classification models.
# Forked from rapidrabbit76/Classification-For-Everyone.
from typing import Final, Dict
import torch
from torch import nn
from .blocks import Classifier, ConvBlock
__all__ = ["VGGModel", "VGG11", "VGG13", "VGG16", "VGG19"]
# Layer recipes for each supported VGG depth, keyed by depth ("11", "13", ...).
# Integers are the output-channel counts of successive ConvBlocks; "M" inserts
# a 2x2, stride-2 max-pool.
# fmt: off
MODEL_TYPES: Final[Dict] = {
    "11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M",
           512, 512, "M"],
    "13": [64, 64, "M", 128, 128, "M", 256, 256, "M",
           512, 512, "M", 512, 512, "M"],
    "16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
           512, 512, 512, "M", 512, 512, 512, "M"],
    "19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
           512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
# fmt: on


class VGGModel(nn.Module):
    """VGG-style image classifier built from a ``MODEL_TYPES`` recipe.

    Args:
        model_type: Which recipe to use. Accepts a bare ``MODEL_TYPES`` key
            (``"11"``, ``"13"``, ``"16"``, ``"19"``) or the same key with a
            case-insensitive ``"VGG"`` prefix (e.g. ``"VGG11"``, the default).
        image_channels: Number of channels in the input images.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability, forwarded to ``Classifier``.

    Raises:
        ValueError: If ``model_type`` does not name a known recipe.
    """

    def __init__(
        self,
        model_type: str = "VGG11",
        image_channels: int = 3,
        num_classes: int = 1000,
        dropout_rate: float = 0.5,
    ) -> None:
        super().__init__()
        key = self._resolve_model_type(model_type)

        layers = []
        in_channels = image_channels
        for spec in MODEL_TYPES[key]:
            if isinstance(spec, int):
                # ConvBlock is defined in .blocks; presumably conv + activation.
                layers.append(ConvBlock(in_channels, spec))
                in_channels = spec
            elif spec == "M":
                layers.append(nn.MaxPool2d(2, 2))
        self.feature_extractor = nn.Sequential(*layers)

        # Pool to a fixed 7x7 map so the classifier's input width does not
        # depend on the input image resolution.
        pool_size = 7
        self.avgpool = nn.AdaptiveAvgPool2d(pool_size)

        # `in_channels` now holds the last conv's output channels (512 for
        # every built-in recipe). The attribute name "classfier" (sic) is
        # kept because renaming it would change state_dict keys and break
        # existing checkpoints.
        self.classfier = nn.Sequential(
            nn.Flatten(),
            Classifier(
                in_channels * pool_size * pool_size,
                4096,
                num_classes,
                dropout_rate,
            ),
        )

    @staticmethod
    def _resolve_model_type(model_type: str) -> str:
        """Map ``model_type`` to a ``MODEL_TYPES`` key, allowing a "VGG" prefix.

        Raises:
            ValueError: If the resulting key is not in ``MODEL_TYPES``. An
                explicit raise is used instead of ``assert`` so the check
                survives ``python -O``.
        """
        key = str(model_type)
        if key.upper().startswith("VGG"):
            key = key[3:]
        if key not in MODEL_TYPES:
            raise ValueError(
                f"{model_type} is not in {' '.join(MODEL_TYPES.keys())}"
            )
        return key

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run features -> adaptive pooling -> flatten + classifier head.

        Assumes ``x`` is an image batch shaped (batch, image_channels, H, W);
        TODO confirm against callers.
        """
        x = self.feature_extractor(x)
        x = self.avgpool(x)
        return self.classfier(x)
def VGG11(
    image_channals: int,
    num_classes: int,
    dropout_rate: float = 0.5,
) -> "VGGModel":
    """Build a VGG11 model (the ``"11"`` recipe in ``MODEL_TYPES``).

    Args:
        image_channals: Number of input image channels. The misspelled name
            is kept so existing keyword callers keep working.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability forwarded to ``VGGModel``.

    Returns:
        A newly constructed ``VGGModel``.
    """
    return VGGModel("11", image_channals, num_classes, dropout_rate)
def VGG13(
    image_channals: int,
    num_classes: int,
    dropout_rate: float = 0.5,
) -> "VGGModel":
    """Build a VGG13 model (the ``"13"`` recipe in ``MODEL_TYPES``).

    Args:
        image_channals: Number of input image channels. The misspelled name
            is kept so existing keyword callers keep working.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability forwarded to ``VGGModel``.

    Returns:
        A newly constructed ``VGGModel``.
    """
    return VGGModel("13", image_channals, num_classes, dropout_rate)
def VGG16(
    image_channals: int,
    num_classes: int,
    dropout_rate: float = 0.5,
) -> "VGGModel":
    """Build a VGG16 model (the ``"16"`` recipe in ``MODEL_TYPES``).

    Args:
        image_channals: Number of input image channels. The misspelled name
            is kept so existing keyword callers keep working.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability forwarded to ``VGGModel``.

    Returns:
        A newly constructed ``VGGModel``.
    """
    return VGGModel("16", image_channals, num_classes, dropout_rate)
def VGG19(
    image_channals: int,
    num_classes: int,
    dropout_rate: float = 0.5,
) -> "VGGModel":
    """Build a VGG19 model (the ``"19"`` recipe in ``MODEL_TYPES``).

    Args:
        image_channals: Number of input image channels. The misspelled name
            is kept so existing keyword callers keep working.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability forwarded to ``VGGModel``.

    Returns:
        A newly constructed ``VGGModel``.
    """
    return VGGModel("19", image_channals, num_classes, dropout_rate)