-
Notifications
You must be signed in to change notification settings - Fork 1
/
model_selector.py
66 lines (61 loc) · 2.53 KB
/
model_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from models.multi_lenet import MultiLeNetO, MultiLeNetR
from models.segnet import SegnetEncoder, SegnetInstanceDecoder, SegnetSegmentationDecoder, SegnetDepthDecoder
from models.pspnet import SegmentationDecoder, get_segmentation_encoder
from models.multi_faces_resnet import ResNet, FaceAttributeDecoder, BasicBlock
import torchvision.models as model_collection
import torch.nn as nn
def _prepare_module(module, params):
    """Wrap *module* in ``nn.DataParallel`` if ``params['parallel']``, move it to the GPU, return it."""
    if params['parallel']:
        module = nn.DataParallel(module)
    # nn.Module.cuda() moves parameters in place; its return value is not needed.
    module.cuda()
    return module


def get_model(params):
    """Build the shared representation network and per-task decoders for a dataset.

    Args:
        params: configuration dict. Reads:
            - ``params['dataset']``: dataset name; selected by substring match
              against 'mnist', 'cityscapes' and 'celeba' (checked in that order).
            - ``params['tasks']``: container of task names to build decoders for.
            - ``params['parallel']``: when truthy, every module is wrapped in
              ``nn.DataParallel``.

    Returns:
        dict mapping ``'rep'`` to the shared representation network and each
        requested task name to its decoder module. Every module has had
        ``.cuda()`` called on it, so a CUDA device must be available.
        Returns ``None`` when the dataset name matches none of the supported
        datasets.
    """
    data = params['dataset']

    if 'mnist' in data:
        model = {'rep': _prepare_module(MultiLeNetR(), params)}
        # Only tasks 'L' and 'R' are recognised; each gets its own MultiLeNetO head.
        for task in ('L', 'R'):
            if task in params['tasks']:
                model[task] = _prepare_module(MultiLeNetO(), params)
        return model

    if 'cityscapes' in data:
        # NOTE: an earlier variant used a SegNet encoder initialised from
        # torchvision's pretrained VGG16:
        #   model['rep'] = SegnetEncoder()
        #   model['rep'].init_vgg16_params(model_collection.vgg16(pretrained=True))
        model = {'rep': _prepare_module(get_segmentation_encoder(), params)}
        # task -> (num_class, task_type) for the SegmentationDecoder heads.
        # Going through one shared helper guarantees each decoder is wrapped
        # under its own key (previously 'I' wrapped model['R'] by mistake).
        decoder_specs = {'S': (19, 'C'), 'I': (2, 'R'), 'D': (1, 'R')}
        for task, (num_class, task_type) in decoder_specs.items():
            if task in params['tasks']:
                decoder = SegmentationDecoder(num_class=num_class, task_type=task_type)
                model[task] = _prepare_module(decoder, params)
        return model

    if 'celeba' in data:
        model = {'rep': _prepare_module(ResNet(BasicBlock, [2, 2, 2, 2]), params)}
        # Every requested task gets its own FaceAttributeDecoder head.
        for task in params['tasks']:
            model[task] = _prepare_module(FaceAttributeDecoder(), params)
        return model

    # NOTE(review): unrecognised dataset names silently yield None (original
    # behaviour, kept for compatibility); callers may want to check for this.
    return None