-
Notifications
You must be signed in to change notification settings - Fork 4
/
test_time_adaptation.py
executable file
·73 lines (55 loc) · 2.46 KB
/
test_time_adaptation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
"""
Test time adaptation module
Editor: Marshall Xu
Last Edited: 22/10/2023
"""
import os
import config.adapt_config as adapt_config
from utils import preprocess_procedure
from utils import TTA_Training
# Parsed command-line / config arguments shared by the whole script.
args = adapt_config.args

ds_path = args.ds_path  # path to original data
ps_path = args.ps_path  # path to preprocessed data
out_path = args.out_path  # path to inferred data

# Create the output folder on first use. makedirs (rather than mkdir) also
# creates any missing parent folders, and exist_ok=True tolerates the folder
# appearing between the existence check and the creation call.
if not os.path.exists(out_path):
    print(f"{out_path} does not exist.")
    os.makedirs(out_path, exist_ok=True)
    print(f"{out_path} has been created!")
prep_mode = args.prep_mode  # preprocessing mode

# Mode 4 means the preprocessing is skipped, so the raw data folder is used
# directly as the input for prediction; any other mode keeps ps_path as is.
ps_path = ds_path if prep_mode == 4 else ps_path
px_path = args.px_path  # path to proxies
if px_path is None:
    # The proxy segmentation is not provided: fall back to a "proxies"
    # folder inside the output path and make sure it exists.
    px_path = os.path.join(out_path, "proxies", "")
    if not os.path.exists(px_path):
        print(f"{px_path} does not exist.")
        # Create an intermediate output folder inside the output path.
        # makedirs with exist_ok=True does not fail if the folder appears
        # between the existence check and this call.
        os.makedirs(px_path, exist_ok=True)
        print(f"{px_path} has been created!")
    # NOTE(review): this sanity check is assumed to apply only to the
    # auto-created folder -- confirm whether a user-supplied px_path should
    # be verified here as well.
    assert os.path.exists(px_path), "Container doesn't initialize properly, contact for maintenance: https://github.com/KMarshallX/vessel_code"
# Output path for the finetuned model: a "finetuned" folder inside out_path.
out_mo_path = os.path.join(out_path, "finetuned", "")
if not os.path.exists(out_mo_path):
    print(f"{out_mo_path} does not exist.")
    # Create an intermediate output folder inside the output path.
    # makedirs with exist_ok=True also creates missing parents and does not
    # fail if the folder appears between the existence check and this call.
    os.makedirs(out_mo_path, exist_ok=True)
    print(f"{out_mo_path} has been created!")
# Sanity check that the folder is really there before the run starts.
assert os.path.exists(out_mo_path), "Container doesn't initialize properly, contact for maintenance: https://github.com/KMarshallX/vessel_code"
# Resource optimization flag, forwarded unchanged to the TTA run below.
resource_opt = args.resource

if __name__ == "__main__":
    print("TTA session will start shortly..")

    # Preprocessing procedure: called with the original-data path, the
    # preprocessed-data path and the selected preprocessing mode.
    preprocess_procedure(ds_path, ps_path, prep_mode)

    # Positional constructor arguments for TTA_Training, kept in the exact
    # order the class is called with.
    tta_settings = (
        args.loss_m,
        args.mo,
        args.ic,
        args.oc,
        args.fil,
        args.op,
        args.lr,
        args.optim_gamma,
        args.ep,
        args.batch_mul,
        args.osz,
        args.aug_mode,
        args.pretrained,
        args.thresh,
        args.cc,
    )

    # Initialize the TTA process.
    tta_process = TTA_Training(*tta_settings)

    # Run the TTA procedure with the input, proxy and output locations.
    tta_process.test_time_adaptation(
        ps_path,
        px_path,
        out_path,
        out_mo_path,
        resource_opt,
    )