# We would like to thank and acknowledge jfzhang95 for the DeepLabV3+ module
# as well as a template for metrics and the training pipeline.
# His code repository can be found here:
# https://github.com/jfzhang95/pytorch-deeplab-xception
import argparse
import os

import numpy as np
import torch
from PIL import Image

from models.deeplab.train import *
from models.deeplab.evaluate import *
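# Minimal example invocations (illustrative values only; the paths, backbone choice,
# and hyperparameters below are assumptions, not project-supplied defaults):
#
#   Training:
#     python run_deeplab.py --dataset urban3d --data-root /data/ \
#         --backbone drn_c42 --out-stride 8 --epochs 50 --batch-size 4 \
#         --loss-type wce_dice --lr 1e-4 --checkname example_run
#
#   Single-file inference (see handle_inference below):
#     python run_deeplab.py --inference --resume example_run \
#         --input-filename tile.npy --output-filename mask.png \
#         --window-size 256 --stride 256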

def main():
    parser = argparse.ArgumentParser(description="DeepLabV3+ Training and Evaluation")

    # model parameters
    parser.add_argument('--backbone', type=str, default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet', 'drn_c42'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out-stride', type=int, default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset', type=str, default='urban3d',
                        choices=['urban3d', 'spaceNet', 'crowdAI', 'combined', 'OSM', 'combined_naip', 'OSM_split4', 'OSM_imonly'],
                        help='dataset name (default: urban3d)')
    parser.add_argument('--data-root', type=str, default='/data/',
                        help='datasets root path')
    parser.add_argument('--workers', type=int, default=4,
                        metavar='N', help='dataloader threads')
    parser.add_argument('--sync-bn', type=bool, default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument('--freeze-bn', action='store_true', default=False,
                        help='whether to freeze bn parameters (default: False)')
    parser.add_argument('--loss-type', type=str, default='ce_dice',
                        choices=['ce', 'ce_dice', 'wce_dice'],
                        help='loss func type (default: ce_dice)')
    parser.add_argument('--fbeta', type=float, default=1, help='beta for FBeta-Measure')
    # parser.add_argument('--loss-weights', type=float, nargs="+", default=[1.0, 1.0],
    #                     help='loss weighting')
    parser.add_argument('--loss-weights', type=str, default='1.0,1.0',
                        help='loss weighting')
    parser.add_argument("--num-classes", type=int, default=2,
                        help='number of classes to predict (2 for binary mask)')
    # parser.add_argument('--dropout', type=float, nargs="+", default=[0.1, 0.5],
    #                     help='dropout values')
    parser.add_argument('--dropout', type=str, default='0.1,0.5',
                        help='dropout values')
    parser.add_argument('--preempt-robust', action='store_true', default=False,
                        help='look for the most recent checkpoint for this checkname before loading the \
                        resume checkpoint; helpful when SLURM preempts and stops the job and you don\'t want to restart from scratch')
    # training hyper params
    parser.add_argument('--epochs', type=int, default=None, metavar='N',
                        help='number of epochs to train (must be specified)')
    parser.add_argument('--start_epoch', type=int, default=0,
                        metavar='N', help='start epoch (default: 0)')
    parser.add_argument('--batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                        training (default: auto)')
    parser.add_argument('--test-batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                        testing (default: auto)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--loss-weights-param', type=float, default=1.01,
                        help='base of exponential function in defining loss weights')

    # optimizer params
    parser.add_argument('--momentum', type=float, default=0.9,
                        metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        metavar='M', help='w-decay (default: 5e-4)')
    parser.add_argument('--nesterov', action='store_true', default=False,
                        help='whether to use nesterov (default: False)')

    # cuda, seed and logging
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--gpu-ids', type=str, default='0',
                        help='which gpus to train on, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')

    # name
    parser.add_argument('--checkname', type=str, default=None,
                        help='set the checkpoint name')
    parser.add_argument('--checkname-add', type=str, default=None,
                        help='suffix appended to the checkpoint name')
    # evaluation options
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')
    parser.add_argument('--use-wandb', action='store_true', default=False)
    parser.add_argument('--resume', type=str, default=None, help='experiment to load')
    parser.add_argument("--evaluate", action='store_true', default=False)
    parser.add_argument('--best-miou', action='store_true', default=False)

    # inference options (includes some evaluation options)
    parser.add_argument('--inference', action='store_true', default=False)
    parser.add_argument('--input-filename', type=str, default=None, help='path to an input file to run inference on')
    parser.add_argument('--output-filename', type=str, default=None, help='path to where the predicted segmentation mask will be written')
    parser.add_argument('--window-size', type=int, default=None, help="the size of grid blocks to sample from the input; use if encountering OOM issues")
    parser.add_argument('--stride', type=int, default=None, help="the stride at which to sample grid blocks; the recommended value is equal to `window_size`")
    parser.add_argument('--minference', action='store_true', default=False)
    parser.add_argument('--output-dir', type=str, default=None,
                        help='path to where multiple predicted segmentation masks will be written')

    # boundaries
    parser.add_argument('--incl-bounds', action='store_true', default=False,
                        help='include boundaries of masks in the loss function')
    parser.add_argument('--bounds-kernel-size', type=int, default=3,
                        help='kernel size for calculating boundary')

    # misc
    parser.add_argument('--owner', type=str, default=None, help='N or A to indicate who\'s running')
    parser.add_argument('--superres', type=int, default=None,
                        help='super-resolution factor of the imagery to use (default: None, i.e. no superres)')
    parser.add_argument('--year', type=int, default=None,
                        help='NAIP year for inference')

    args = parser.parse_args()
    run_deeplab(args)

def run_deeplab(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu-ids must be a comma-separated list of integers only')

    try:
        args.loss_weights = [float(s) for s in args.loss_weights.split(',')]
    except ValueError:
        raise ValueError('Argument --loss-weights must be a comma-separated list of 2 floats')
    try:
        args.dropout = [float(s) for s in args.dropout.split(',')]
    except ValueError:
        raise ValueError('Argument --dropout must be a comma-separated list of 2 floats')
    assert len(args.loss_weights) == 2
    assert len(args.dropout) == 2

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        raise ValueError("epochs must be specified")
    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)
    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size
    if args.checkname is None:
        # Prefix the checkpoint name with the location inferred from the data root;
        # fall back to the combined label if neither city matches.
        dic = {'los_angeles': 'LA', 'san_jose': 'SJ'}
        loc = 'SJ+LA_'
        for k, v in dic.items():
            if k in args.data_root:
                loc = f'{v}_'
                break
        args.checkname = loc + f'{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.batch_size}'
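        # For illustration only (not an output captured from a real run): with the
        # parser defaults, a single GPU, and a data root containing 'san_jose', the
        # generated checkname would be 'SJ_1.0_False_0.0001_0.0005_1.01_4'.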
    if args.superres is not None:
        args.checkname += f'_superresx{args.superres}'
    if args.year is not None:
        args.checkname += f'_{args.year}'
    if args.checkname_add is not None:
        args.checkname += f'_{args.checkname_add}'

    # if args.checkname_add is None:
    #     if 'los_angeles' in args.data_root:
    #         args.checkname = f'LA_{args.dataset}_{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.resume}'
    #     elif 'san_jose' in args.data_root:
    #         args.checkname = f'SJ_{args.dataset}_{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.resume}'
    #     else:
    #         args.checkname = f'{args.dataset}_{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.resume}'
    # else:
    #     if 'los_angeles' in args.data_root:
    #         args.checkname = f'LA_{args.dataset}_{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.resume}_{args.checkname_add}'
    #     elif 'san_jose' in args.data_root:
    #         args.checkname = f'SJ_{args.dataset}_{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.resume}_{args.checkname_add}'
    #     else:
    #         args.checkname = f'{args.dataset}_{args.fbeta}_{args.freeze_bn}_{args.lr}_{args.weight_decay}_{args.loss_weights_param}_{args.resume}_{args.checkname_add}'

    if args.preempt_robust:
        checkpoint_name = "most_recent_epoch_checkpoint.pth.tar"
        # check if the checkname path exists
        if os.path.exists(os.path.join('/oak/stanford/groups/deho/building_compliance/rgb-footprint-extract/weights', args.checkname, checkpoint_name)):
            # if it does, resume from checkname; allows us to automatically restart our training job if slurm preempts
            args.resume = os.path.join(args.checkname, checkpoint_name)

    torch.manual_seed(args.seed)

    if args.inference:
        handle_inference(args)
    elif args.evaluate:
        handle_evaluate(args)
    elif args.minference:
        handle_multiple_inference(args)
    else:
        handle_training(args)

def handle_inference(args):
    # Validate arguments
    input_formats, output_formats = {".npy": "numpy"}, [".npy", ".png", ".tiff"]
    get_ext = lambda filename: os.path.splitext(filename)[-1] if filename else None
    input_ext, output_ext = get_ext(args.input_filename), get_ext(args.output_filename)

    assert args.input_filename and input_ext in input_formats, f"Accepted input file formats: {list(input_formats.keys())}"
    assert args.output_filename and output_ext in output_formats, f"Accepted output formats: {output_formats}"
    if args.window_size or args.stride:
        assert args.window_size and args.stride, "Both `window_size` and `stride` must be set."
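    # When --window-size and --stride are given, inference runs over grid blocks
    # sampled from the input (per the flags' help text) rather than the full raster
    # at once, which helps avoid OOM on large inputs; stride == window_size yields
    # non-overlapping tiles. The tiling itself is presumably implemented in Tester.infer().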
    args.dataset = input_formats[os.path.splitext(args.input_filename)[-1]]
    args.test_batch_size = 1

    tester = Tester(args)
    print("Inference starting on {}...".format(args.input_filename))
    final_output = tester.infer()
    assert len(final_output.shape) == 2

    if output_ext == ".png":
        # scale the binary mask to 0-255 and cast to uint8 so PIL can encode it
        Image.fromarray((final_output * 255).astype(np.uint8)).save(args.output_filename)
    elif output_ext == ".npy":
        np.save(args.output_filename, final_output)
    elif output_ext == ".tiff":
        raise NotImplementedError("TIFF output support is coming soon.")

def handle_multiple_inference(args):
    # Validate arguments
    input_formats, output_formats = {".npy": "numpy"}, [".npy", ".png", ".tiff"]
    # get_ext = lambda filename: os.path.splitext(filename)[-1] if filename else None
    # input_ext, output_ext = get_ext(args.input_filename), get_ext(args.output_filename)
    # assert args.input_filename and input_ext in input_formats, f"Accepted input file formats: {input_formats.keys()}"
    # assert args.output_filename and output_ext in output_formats, f"Accepted output formats: {output_formats}"
    if args.window_size or args.stride:
        assert args.window_size and args.stride, "Both `window_size` and `stride` must be set."
    # args.dataset = input_formats[os.path.splitext(args.input_filename)[-1]]
    args.test_batch_size = 1

    if args.dataset == 'OSM_imonly':
        tester = Tester1(args, '')  # partition doesn't matter because we're just going to read from args.data_root
        tester.infer_multiple()
    else:
        for i in ['train', 'val', 'test']:
            tester = Tester1(args, i)
            tester.infer_multiple()

def handle_evaluate(args):
    tester = Tester(args)
    print("Experiment {} instantiated. Evaluation starting...".format(args.checkname))
    tester.test()

def handle_training(args):
    trainer = Trainer(args)
    print("Learning rate: {}; L2 factor: {}".format(args.lr, args.weight_decay))
    print("Experiment {} instantiated. Training starting...".format(args.checkname))
    # NEW
    print("Starting from epoch {}".format(trainer.start_epoch))
    print("Training for {} epochs".format(trainer.args.epochs))
    print("Batch size: {}; Test Batch Size: {}".format(args.batch_size, args.test_batch_size))

    for epoch in range(trainer.start_epoch, trainer.start_epoch + trainer.args.epochs):
        trainer.training(epoch)
        if not trainer.args.no_val:
            trainer.validation(epoch)


if __name__ == "__main__":
    main()