24
24
from torch .nn .utils .rnn import PackedSequence
25
25
from torch .nn .utils .rnn import pack_padded_sequence , pad_packed_sequence
26
26
27
- from baseline import parse_args , Seq2SeqModel , write_transcripts
27
+ from main import parse_args , Seq2SeqModel , write_transcripts
28
28
from model_utils import *
29
29
30
30
31
31
def main ():
32
32
args = parse_args ()
33
- args .cuda = not args .no_cuda and torch .cuda .is_available ()
34
33
35
34
t0 = time .time ()
36
35
@@ -41,16 +40,9 @@ def main():
41
40
pass
42
41
43
42
print ("Loading File Paths" )
44
- train_paths , dev_paths , test_paths = load_paths ()
45
- train_paths , dev_paths , test_paths = train_paths [:args .max_train ], dev_paths [:args .max_dev ], test_paths [:args .max_test ]
46
- t1 = time .time ()
47
- print_log ('%.2f Seconds' % (t1 - t0 ), LOG_PATH )
48
-
49
- print ("Loading Y Data" )
50
- test_paths = test_paths [:args .max_data ]
51
- train_ys = load_y_data ('train' ) # 1-dim np array of strings
52
- dev_ys = load_y_data ('dev' )
53
- test_ys = load_y_data ('test' )
43
+ train_ids , train_ys = load_fid_and_y_data ('train' )
44
+ dev_ids , dev_ys = load_fid_and_y_data ('dev' )
45
+ test_ids , test_ys = load_fid_and_y_data ('test' )
54
46
t1 = time .time ()
55
47
print_log ('%.2f Seconds' % (t1 - t0 ), LOG_PATH )
56
48
@@ -64,13 +56,13 @@ def main():
64
56
print ("Mapping Characters" )
65
57
testchars = map_characters (test_ys , charmap )
66
58
print ("Building Loader" )
67
- test_loader = make_loader (test_paths , testchars , args , shuffle = False , batch_size = 1 )
59
+ test_loader = make_loader (test_ids , testchars , args , shuffle = False , batch_size = 1 )
68
60
69
61
print ("Building Model" )
70
62
model = Seq2SeqModel (args , vocab_size = charcount , beam_width = args .beam_width )
71
63
72
64
CKPT_PATH = os .path .join (args .save_directory , 'model.ckpt' )
73
- if args .cuda :
65
+ if torch .cuda . is_available () :
74
66
model .load_state_dict (torch .load (CKPT_PATH ))
75
67
else :
76
68
gpu_dict = torch .load (CKPT_PATH , map_location = lambda storage , loc : storage )
@@ -80,7 +72,7 @@ def main():
80
72
model .load_state_dict (cpu_model_dict )
81
73
print ("Loaded Checkpoint" )
82
74
83
- if args .cuda :
75
+ if torch .cuda . is_available () :
84
76
model = model .cuda ()
85
77
86
78
model .eval ()
0 commit comments