-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel.py
48 lines (39 loc) · 1.51 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from torch import nn
import resnet
from utils import _construct_depth_model
def generate_model(args):
    """Build the ResNet model selected by ``args`` and prepare it for training/testing.

    Args:
        args: Namespace-like configuration object (e.g. from argparse). The
            attributes read here are ``pre_train_model``, ``mode``,
            ``model_depth``, ``feature_dim``, ``sample_size``,
            ``sample_duration``, ``shortcut_type``, ``tracking`` and
            ``use_cuda``.

    Returns:
        The network built by ``resnet.resnet{18,50,101}``, wrapped in
        ``nn.DataParallel``, passed through ``_construct_depth_model`` (from
        ``utils``), and moved to the GPU with ``.cuda()`` when
        ``args.use_cuda`` is truthy.

    Raises:
        ValueError: if ``args.model_depth`` is not 18, 50 or 101.
    """
    if not args.pre_train_model or args.mode == 'test':
        print('Without Pre-trained model')

    # Map each supported depth to the name of its builder in the ``resnet``
    # module; the attribute is only looked up for the selected depth.
    builder_names = {18: 'resnet18', 50: 'resnet50', 101: 'resnet101'}
    depth = args.model_depth
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which previously let an unsupported depth fall through
    # and fail later with an UnboundLocalError on ``model``.
    if depth not in builder_names:
        raise ValueError(
            f"Unsupported model_depth {depth!r}; "
            f"expected one of {sorted(builder_names)}"
        )

    builder = getattr(resnet, builder_names[depth])
    model = builder(
        output_dim=args.feature_dim,
        sample_size=args.sample_size,
        sample_duration=args.sample_duration,
        shortcut_type=args.shortcut_type,
        tracking=args.tracking,
        pre_train=args.pre_train_model,
    )

    # device_ids=None: DataParallel uses all visible CUDA devices.
    model = nn.DataParallel(model, device_ids=None)
    # Project-specific post-processing defined in utils (semantics not visible here).
    model = _construct_depth_model(model)
    if args.use_cuda:
        model = model.cuda()
    return model