forked from valhongli/reID-PCB
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
83 lines (81 loc) · 2.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import sys
import argparse
from train import train
from test import test
import torch
def main(args):
    """Run the training and/or testing stage selected by ``args.stage``.

    Args:
        args: parsed command-line namespace. ``args.stage`` must be one of
            ``'train'``, ``'test'`` or ``'all'``. The whole namespace is
            forwarded unchanged to ``train`` and/or ``test``.

    Raises:
        ValueError: if ``args.stage`` is not a recognised stage. Previously
            such a value silently ran nothing.
    """
    valid_stages = ('train', 'test', 'all')
    if args.stage not in valid_stages:
        raise ValueError(
            "unknown stage %r; expected one of %s"
            % (args.stage, ', '.join(valid_stages))
        )
    # Fixed seed for torch's RNG so repeated runs start from the same state.
    torch.manual_seed(960202)
    if args.stage in ('all', 'train'):
        train(args)
    if args.stage in ('all', 'test'):
        test(args)
def build_parser():
    """Build the command-line parser for the re-ID reproduction script.

    Integer 0/1 switches (``--use-gpu``, ``--last-conv``, ``--load-once``)
    stay ints here. ``prepare_args`` converts them to booleans.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description='Person Re-Identification Reproduce'
    )
    parser.add_argument('--params-filename',
                        type=str, default='reid.pth.tar',
                        help='filename of model parameters.')
    parser.add_argument('--use-gpu',
                        type=int, default=1,
                        help='set 1 if want to use GPU, otherwise 0. (default 1)')
    parser.add_argument('--world-size',
                        type=int, default=1,
                        help='number of distributed processes. (default 1)')
    parser.add_argument('--dist-url',
                        type=str, default='tcp://127.0.0.1:2222',
                        help='the master-node\'s address and port')
    parser.add_argument('--dist-rank',
                        type=int, default=0,
                        help='rank of distributed process. (default 0)')
    parser.add_argument('--last-conv',
                        type=int, default=1,
                        help='whether contains last convolution layer. (default 1)')
    parser.add_argument('--batch-size',
                        type=int, default=64,
                        help='training data batch size. (default 64)')
    parser.add_argument('--num-workers',
                        type=int, default=20,
                        help='number of workers when loading data. (default 20)')
    parser.add_argument('--load-once',
                        type=int, default=0,
                        help='load all of data at once. (default 0)')
    parser.add_argument('--epoch',
                        type=int, default=60,
                        help="number of epochs. (default 60)")
    # Restrict to the stage values main() acts on, so a typo is rejected at
    # parse time instead of being passed through.
    parser.add_argument('--stage',
                        type=str, default='train',
                        choices=('train', 'test', 'all'),
                        help='running stage. train, test or all. (default train)')
    parser.add_argument('--test-type',
                        type=str, default='pcb',
                        help='model type when testing. pcb, rpp or fnl. (default pcb)')
    parser.add_argument('--rpp-std',
                        type=float, default=0.01,
                        help='standard deviation of initialization of rpp layer. (default 0.01)')
    parser.add_argument('--conv-std',
                        type=float, default=0.001,
                        help='standard deviation of initialization of conv layer. (default 0.001)')
    return parser


def prepare_args(args):
    """Derive runtime settings from parsed command-line arguments.

    Converts the 0/1 integer switches to booleans and sets the
    ``distributed`` flag. It also resolves the dataset directory and model
    file under ``$TORCH_HOME``, which defaults to ``~/.torch``.

    Args:
        args (argparse.Namespace): result of ``build_parser().parse_args()``.

    Returns:
        argparse.Namespace: the same object, updated in place.
    """
    # Switches arrive as ints; only exactly 1 counts as "on".
    args.use_gpu = args.use_gpu == 1
    args.last_conv = args.last_conv == 1
    args.load_once = args.load_once == 1
    # More than one process means distributed execution.
    args.distributed = args.world_size > 1
    args.home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
    args.dataset = os.path.join(args.home, 'datasets')
    args.model_file = os.path.join(args.home, 'models', args.params_filename)
    return args


if __name__ == '__main__':
    main(prepare_args(build_parser().parse_args()))