Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xungeer29 committed Apr 7, 2019
1 parent 93add71 commit 7d542eb
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 10 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
class DefaultConfigs(object):
data_root = '/media/gfx/data1/DATA/lida/UCMerced_LandUse/UCMerced_LandUse/Images' # 数据集的根目录
model = 'ResNet152' # ResNet34 使用的模型
model = 'ResNet152' # ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 使用的模型
freeze = True # 是否冻结卷基层

seed = 1000 # 固定随机种子
Expand Down
58 changes: 49 additions & 9 deletions networks/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,47 @@ def l2_norm(x):
x = torch.div(x, norm)
return x

class ResNet18(nn.Module):
def __init__(self, model, num_classes=1000):
super(ResNet18, self).__init__()
self.backbone = model

self.fc1 = nn.Linear(2048, 1024)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(1024, num_classes)

def forward(self, x):
# x = whitening(x)
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)

x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)

x = self.backbone.avgpool(x)

x = x.view(x.size(0), -1)
x = l2_norm(x)
x = self.dropout(x)
x = self.fc1(x)
x = l2_norm(x)
x = self.dropout(x)
x = self.fc2(x)

return x

class ResNet34(nn.Module):
def __init__(self, model, num_classes=1000):
super(ResNet34, self).__init__()
self.backbone = model

self.fc1 = nn.Linear(8192, 2048)
self.fc1 = nn.Linear(2048, 1024)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(2048, num_classes)
self.fc2 = nn.Linear(1024, num_classes)

def forward(self, x):
# x = whitening(x)
Expand Down Expand Up @@ -57,8 +90,9 @@ def __init__(self, model, num_classes=1000):
super(ResNet50, self).__init__()
self.backbone = model

self.fc1 = nn.Linear(8192, 2048)
self.dropout = nn.Dropout(0.5)
self.fc = nn.Linear(2048, num_classes)
self.fc2 = nn.Linear(2048, num_classes)


def forward(self, x):
Expand All @@ -78,7 +112,10 @@ def forward(self, x):
x = x.view(x.size(0), -1)
x = l2_norm(x)
x = self.dropout(x)
x = self.fc(x)
x = self.fc1(x)
x = l2_norm(x)
x = self.dropout(x)
x = self.fc2(x)

return x

Expand All @@ -88,9 +125,9 @@ def __init__(self, model, num_classes=1000):
super(ResNet101, self).__init__()
self.backbone = model

self.fc1 = nn.Linear(8192, 2048)
self.dropout = nn.Dropout(0.5)
self.fc = nn.Linear(2048, num_classes)

self.fc2 = nn.Linear(2048, num_classes)

def forward(self, x):
#x = whitening(x)
Expand All @@ -109,7 +146,10 @@ def forward(self, x):
x = x.view(x.size(0), -1)
x = l2_norm(x)
x = self.dropout(x)
x = self.fc(x)
x = self.fc1(x)
x = l2_norm(x)
x = self.dropout(x)
x = self.fc2(x)

return x

Expand Down Expand Up @@ -148,8 +188,8 @@ def forward(self, x):
return x

if __name__ == '__main__':
backbone = models.resnet152(pretrained=True)
models = ResNet152(backbone, 21)
backbone = models.resnet101(pretrained=True)
models = ResNet101(backbone, 21)
data = torch.randn(1, 3, 256, 256)
x = models(data)
#print(x)
Expand Down
6 changes: 6 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def train():
elif config.model == 'ResNet34':
backbone = models.resnet34(pretrained=True)
model = ResNet34(backbone, num_classes=config.num_classes)
elif config.model == 'ResNet50':
backbone = models.resnet50(pretrained=True)
model = ResNet50(backbone, num_classes=config.num_classes)
elif config.model == 'ResNet101':
backbone = models.resnet101(pretrained=True)
model = ResNet101(backbone, num_classes=config.num_classes)
elif config.model == 'ResNet152':
backbone = models.resnet152(pretrained=True)
model = ResNet152(backbone, num_classes=config.num_classes)
Expand Down

0 comments on commit 7d542eb

Please sign in to comment.