Skip to content

Commit 4fa30a1

Browse files
committed
new beam_las
1 parent f0ae86d commit 4fa30a1

File tree

3 files changed

+10 -32 lines changed

tasks/Miami/beam_las/main.py

+2 -4
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def forward(self, utterances, utterance_lengths):
127127
sorted_lengths, order = torch.sort(utterance_lengths, 0, descending=True)
128128
_, backorder = torch.sort(order, 0)
129129
h = h[:, order, :]
130-
h = pack_padded_sequence(h, sorted_lengths) # .data.cpu().numpy())
130+
h = pack_padded_sequence(h, sorted_lengths.data.cpu().numpy())
131131

132132
# RNNs
133133
for rnn in self.rnns:
@@ -577,7 +577,7 @@ def main():
577577
print_log('%.2f Seconds' % (t1-t0), LOG_PATH)
578578

579579
print("Running")
580-
CKPT_PATH = os.path.join(args.save_directory, 'best_model.ckpt')
580+
CKPT_PATH = os.path.join(args.save_directory, 'model.ckpt')
581581
if os.path.exists(CKPT_PATH):
582582
model.load_state_dict(torch.load(CKPT_PATH))
583583
if torch.cuda.is_available():
@@ -643,8 +643,6 @@ def main():
643643
torch.save(model.state_dict(), CKPT_PATH)
644644
elif e - prev_best_epoch > args.patience:
645645
break
646-
torch.save(model.state_dict(), os.path.join(args.save_directory, f'model_{e}.ckpt'))
647-
print(f'Saved model epoch {e}')
648646
print_log('Val Loss: %f' % val_loss, LOG_PATH)
649647
print_log('Avg Val Perplexity: %f' % (tot_perp/len(train_loader.dataset)), LOG_PATH)
650648
cer_val = cer(args, model, dev_loader, charset, dev_ys, device=args.cuda)

tasks/Miami/beam_las/model_utils.py

+1 -13
Original file line number | Diff line number | Diff line change
@@ -229,21 +229,9 @@ def __init__(self, ids, labels=None):
229229
'''
230230
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
231231
self.mfcc_dir = os.path.join(parent_dir, 'data/mfcc')
232-
mfcc_files = os.listdir(self.mfcc_dir)
233-
mfcc_paths_set = set([os.path.join(self.mfcc_dir, f) for f in mfcc_files])
234232
self.ids = ids
235233
if labels:
236234
self.labels = [torch.from_numpy(y + 1).long() for y in labels] # +1 for start/end token
237-
new_ids = []
238-
new_labels = []
239-
for i, label in enumerate(self.labels):
240-
curr_id = self.ids[i]
241-
curr_mfcc_path = os.path.join(self.mfcc_dir, curr_id+'.mfcc')
242-
if curr_mfcc_path in mfcc_paths_set:
243-
new_ids.append(curr_id)
244-
new_labels.append(label)
245-
self.ids = new_ids
246-
self.labels = new_labels
247235
assert len(self.ids) == len(self.labels)
248236
else:
249237
self.labels = None
@@ -304,7 +292,7 @@ def make_loader(ids, labels, args, shuffle=True, batch_size=64):
304292
labels: list of 1-dim int np arrays
305293
'''
306294
# Build the DataLoaders
307-
kwargs = {'pin_memory': True, 'num_workers': args.num_workers} if args.cuda else {}
295+
kwargs = {'pin_memory': True, 'num_workers': args.num_workers} if torch.cuda.is_available() else {}
308296
dataset = ASRDataset(ids, labels)
309297
loader = DataLoader(dataset, collate_fn=speech_collate_fn, shuffle=shuffle, batch_size=batch_size, **kwargs)
310298
return loader

tasks/Miami/beam_las/test_model.py

+7 -15
Original file line number | Diff line number | Diff line change
@@ -24,13 +24,12 @@
2424
from torch.nn.utils.rnn import PackedSequence
2525
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
2626

27-
from baseline import parse_args, Seq2SeqModel, write_transcripts
27+
from main import parse_args, Seq2SeqModel, write_transcripts
2828
from model_utils import *
2929

3030

3131
def main():
3232
args = parse_args()
33-
args.cuda = not args.no_cuda and torch.cuda.is_available()
3433

3534
t0 = time.time()
3635

@@ -41,16 +40,9 @@ def main():
4140
pass
4241

4342
print("Loading File Paths")
44-
train_paths, dev_paths, test_paths = load_paths()
45-
train_paths, dev_paths, test_paths = train_paths[:args.max_train], dev_paths[:args.max_dev], test_paths[:args.max_test]
46-
t1 = time.time()
47-
print_log('%.2f Seconds' % (t1-t0), LOG_PATH)
48-
49-
print("Loading Y Data")
50-
test_paths = test_paths[:args.max_data]
51-
train_ys = load_y_data('train') # 1-dim np array of strings
52-
dev_ys = load_y_data('dev')
53-
test_ys = load_y_data('test')
43+
train_ids, train_ys = load_fid_and_y_data('train')
44+
dev_ids, dev_ys = load_fid_and_y_data('dev')
45+
test_ids, test_ys = load_fid_and_y_data('test')
5446
t1 = time.time()
5547
print_log('%.2f Seconds' % (t1-t0), LOG_PATH)
5648

@@ -64,13 +56,13 @@ def main():
6456
print("Mapping Characters")
6557
testchars = map_characters(test_ys, charmap)
6658
print("Building Loader")
67-
test_loader = make_loader(test_paths, testchars, args, shuffle=False, batch_size=1)
59+
test_loader = make_loader(test_ids, testchars, args, shuffle=False, batch_size=1)
6860

6961
print("Building Model")
7062
model = Seq2SeqModel(args, vocab_size=charcount, beam_width=args.beam_width)
7163

7264
CKPT_PATH = os.path.join(args.save_directory, 'model.ckpt')
73-
if args.cuda:
65+
if torch.cuda.is_available():
7466
model.load_state_dict(torch.load(CKPT_PATH))
7567
else:
7668
gpu_dict = torch.load(CKPT_PATH, map_location=lambda storage, loc: storage)
@@ -80,7 +72,7 @@ def main():
8072
model.load_state_dict(cpu_model_dict)
8173
print("Loaded Checkpoint")
8274

83-
if args.cuda:
75+
if torch.cuda.is_available():
8476
model = model.cuda()
8577

8678
model.eval()

0 commit comments

Comments
 (0)