-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_net.py
63 lines (43 loc) · 3.83 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from end2end import End2End
import argparse
from config import Config
def str2bool(v):
    """Convert a command-line style truth value to a ``bool``.

    Used as the ``type=`` callable for boolean argparse options.

    Args:
        v: Either a string such as ``"yes"``/``"true"``/``"t"``/``"1"``
            (case-insensitive), or an actual ``bool``.

    Returns:
        ``True`` when ``v`` is one of the accepted truthy spellings (or is
        the boolean ``True``); ``False`` for every other string (or for the
        boolean ``False``).
    """
    # A real boolean has no ``.lower()``; pass it through unchanged instead
    # of raising AttributeError.
    if isinstance(v, bool):
        return v
    return v.lower() in ("yes", "true", "t", "1")
def parse_args():
    """Build the command-line parser for a training run and parse the arguments.

    Boolean options are converted with ``str2bool``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``, with
        one attribute per option declared below.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # NOTE(review): several options below combine ``required=True`` with a
    # ``default``. argparse reports an error when a required option is
    # missing, so those defaults are never applied. They are kept unchanged.

    # --- Run identification, hardware and data location ---
    add(
        '--GPU',
        type=int,
        default=2,
        required=True,
        help="ID of GPU used for training/evaluation.",
    )
    add(
        '--RUN_NAME',
        type=str,
        default='runs',
        required=True,
        help="Name of the run, used as directory name for storing results.",
    )
    add(
        '--DATASET',
        type=str,
        default='KSDD',
        required=True,
        help="Which dataset to use.",
    )
    add(
        '--DATASET_PATH',
        type=str,
        default='./datasets/KSDD/',
        required=True,
        help="Path to the dataset.",
    )

    # --- Optimisation hyper-parameters ---
    add(
        '--EPOCHS',
        type=int,
        default=100,
        required=True,
        help="Number of training epochs.",
    )
    add(
        '--LEARNING_RATE',
        type=float,
        default=1.0,
        required=True,
        help="Learning rate.",
    )
    add(
        '--DELTA_CLS_LOSS',
        type=float,
        default=0.01,
        required=True,
        help="Weight delta for classification loss.",
    )
    add(
        '--BATCH_SIZE',
        type=int,
        default=16,
        required=True,
        help="Batch size for training.",
    )

    # --- Loss shaping and sampling switches ---
    add(
        '--WEIGHTED_SEG_LOSS',
        type=str2bool,
        default=True,
        required=True,
        help="Whether to use weighted segmentation loss.",
    )
    add(
        '--WEIGHTED_SEG_LOSS_P',
        type=float,
        default=2,
        required=False,
        help="Degree of polynomial for weighted segmentation loss.",
    )
    add(
        '--WEIGHTED_SEG_LOSS_MAX',
        type=float,
        default=1,
        required=False,
        help="Scaling factor for weighted segmentation loss.",
    )
    add(
        '--DYN_BALANCED_LOSS',
        type=str2bool,
        default=True,
        required=True,
        help="Whether to use dynamically balanced loss.",
    )
    add(
        '--GRADIENT_ADJUSTMENT',
        type=str2bool,
        default=True,
        required=True,
        help="Whether to use gradient adjustment.",
    )
    add(
        '--FREQUENCY_SAMPLING',
        type=str2bool,
        default=True,
        required=False,
        help="Whether to use frequency-of-use based sampling.",
    )
    add(
        '--DILATE',
        type=int,
        default=7,
        required=False,
        help="Size of dilation kernel for labels",
    )

    # --- Dataset split / sample counts ---
    add(
        '--FOLD',
        type=int,
        default=0,
        help="Which fold (KSDD) or class (DAGM) to train.",
    )
    add(
        '--TRAIN_NUM',
        type=int,
        default=33,
        help="Number of positive training samples for KSDD or STEEL.",
    )
    add(
        '--NUM_SEGMENTED',
        type=int,
        default=33,
        required=True,
        help="Number of segmented positive samples.",
    )

    # --- Output and validation ---
    add(
        '--RESULTS_PATH',
        type=str,
        default='./results/',
        help="Directory to which results are saved.",
    )
    add(
        '--VALIDATE',
        type=str2bool,
        default=True,
        help="Whether to validate during training.",
    )
    add(
        '--VALIDATE_ON_TEST',
        type=str2bool,
        default=True,
        help="Whether to validate on test set.",
    )
    add(
        '--VALIDATION_N_EPOCHS',
        type=int,
        default=8,
        help="Number of epochs between consecutive validation runs.",
    )
    add(
        '--USE_BEST_MODEL',
        type=str2bool,
        default=True,
        help="Whether to use the best model according to validation metrics for evaluation.",
    )

    # --- Runtime behaviour (``None`` default means "not given on the CLI") ---
    add(
        '--ON_DEMAND_READ',
        type=str2bool,
        default=None,
        help="Whether to use on-demand read of data from disk instead of storing it in memory.",
    )
    add(
        '--REPRODUCIBLE_RUN',
        type=str2bool,
        default=None,
        help="Whether to fix seeds and disable CUDA benchmark mode.",
    )
    add(
        '--MEMORY_FIT',
        type=int,
        default=None,
        help="How many images can be fitted in GPU memory.",
    )
    add(
        '--SAVE_IMAGES',
        type=str2bool,
        default=True,
        help="Save test images or not.",
    )

    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: parse the CLI, build the configuration from it,
    # then hand the configuration to End2End and start training.
    cli_args = parse_args()

    cfg = Config()
    cfg.merge_from_args(cli_args)
    # init_extra() is called only after the CLI values have been merged in.
    cfg.init_extra()

    trainer = End2End(cfg=cfg)
    trainer.train()