forked from SeanNaren/deepspeech.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
transcribe.py
80 lines (64 loc) · 2.9 KB
/
transcribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import warnings
from opts import add_decoder_args, add_inference_args
from utils import load_model
warnings.simplefilter('ignore')
from decoder import GreedyDecoder
import torch
from data.data_loader import SpectrogramParser
from model import DeepSpeech
import os.path
import json
def decode_results(model, decoded_output, decoded_offsets, args=None):
    """Package decoder output and run metadata into a JSON-serialisable dict.

    Args:
        model: Accepted for interface compatibility; not read by this function.
        decoded_output: Candidate transcriptions indexed as
            ``decoded_output[b][pi]`` (utterance ``b``, candidate path ``pi``).
        decoded_offsets: Indexed the same way as ``decoded_output``; each entry
            must provide ``.tolist()``. Only read when ``args.offsets`` is true.
        args: Parsed command-line options (``argparse.Namespace``) providing
            ``model_path``, ``lm_path``, ``alpha``, ``beta``, ``decoder``,
            ``top_paths`` and ``offsets``. When omitted, the module-level
            ``args`` created by the ``__main__`` block is used, so existing
            callers keep working.

    Returns:
        A dict with:
          * ``"output"``: one entry per kept path, holding ``"transcription"``
            and, when ``args.offsets`` is set, ``"offsets"``;
          * ``"_meta"``: acoustic-model, language-model and decoder settings.

    Raises:
        NameError: if ``args`` is not passed and no module-level ``args`` exists.
    """
    if args is None:
        # Backward compatibility: this function historically read the global
        # ``args`` parsed in ``__main__``; fall back to it when not passed.
        args = globals().get('args')
        if args is None:
            raise NameError("decode_results() requires parsed options: pass args=... "
                            "or run this module as a script")

    has_lm = args.lm_path is not None
    results = {
        "output": [],
        "_meta": {
            "acoustic_model": {
                "name": os.path.basename(args.model_path)
            },
            "language_model": {
                # Truthiness test kept as in the original: an empty lm_path
                # reports no name.
                "name": os.path.basename(args.lm_path) if args.lm_path else None,
            },
            "decoder": {
                "lm": has_lm,
                # alpha/beta are reported only when an LM is configured.
                "alpha": args.alpha if has_lm else None,
                "beta": args.beta if has_lm else None,
                "type": args.decoder,
            }
        }
    }
    for b, paths in enumerate(decoded_output):
        # Keep at most ``top_paths`` candidate transcriptions per utterance.
        for pi in range(min(args.top_paths, len(paths))):
            result = {'transcription': paths[pi]}
            if args.offsets:
                result['offsets'] = decoded_offsets[b][pi].tolist()
            results['output'].append(result)
    return results
def transcribe(audio_path, parser, model, decoder, device):
    """Run the acoustic model on one audio file and decode its output.

    Args:
        audio_path: Path passed to ``parser.parse_audio``.
        parser: Object whose ``parse_audio(path)`` returns a 2-D tensor
            (the script passes a ``SpectrogramParser``).
        model: Callable ``model(inputs, input_sizes) -> (out, output_sizes)``.
        decoder: Object whose ``decode(out, output_sizes)`` returns
            ``(decoded_output, decoded_offsets)``.
        device: ``torch.device`` the input tensor is moved to.

    Returns:
        The ``(decoded_output, decoded_offsets)`` pair from ``decoder.decode``.
    """
    # Inference only: disable autograd so no computation graph is recorded
    # and kept alive for the forward pass.
    with torch.no_grad():
        # ``view`` below requires a contiguous tensor.
        spect = parser.parse_audio(audio_path).contiguous()
        # Add batch and channel dimensions: (rows, cols) -> (1, 1, rows, cols).
        spect = spect.view(1, 1, spect.size(0), spect.size(1)).to(device)
        # Length of the last axis (presumably time frames — confirm); this
        # stays on the CPU, as in the original.
        input_sizes = torch.IntTensor([spect.size(3)])
        out, output_sizes = model(spect, input_sizes)
        decoded_output, decoded_offsets = decoder.decode(out, output_sizes)
    return decoded_output, decoded_offsets
if __name__ == '__main__':
    # Command-line interface: shared inference options, this script's own
    # flags, then decoder options (helpers come from the project ``opts`` module).
    arg_parser = argparse.ArgumentParser(description='DeepSpeech transcription')
    arg_parser = add_inference_args(arg_parser)
    arg_parser.add_argument('--audio-path', default='audio.wav',
                            help='Audio file to predict on')
    arg_parser.add_argument('--offsets', dest='offsets', action='store_true',
                            help='Returns time offset information')
    arg_parser = add_decoder_args(arg_parser)
    # NOTE: must remain a module-level name called ``args``; decode_results() reads it.
    args = arg_parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")
    model = load_model(device, args.model_path, args.cuda)

    if args.decoder != "beam":
        decoder = GreedyDecoder(model.labels, blank_index=model.labels.index('_'))
    else:
        # Imported only on this branch (presumably so the greedy path does not
        # need the beam-search dependency — confirm).
        from decoder import BeamCTCDecoder
        beam_options = dict(
            lm_path=args.lm_path,
            alpha=args.alpha,
            beta=args.beta,
            cutoff_top_n=args.cutoff_top_n,
            cutoff_prob=args.cutoff_prob,
            beam_width=args.beam_width,
            num_processes=args.lm_workers,
        )
        decoder = BeamCTCDecoder(model.labels, **beam_options)

    spect_parser = SpectrogramParser(model.audio_conf, normalize=True)
    decoded_output, decoded_offsets = transcribe(args.audio_path, spect_parser, model,
                                                 decoder, device)
    report = decode_results(model, decoded_output, decoded_offsets)
    print(json.dumps(report))