diff --git a/examples/mms/README.md b/examples/mms/README.md index 48460b375b..dfd8b85fbd 100644 --- a/examples/mms/README.md +++ b/examples/mms/README.md @@ -53,7 +53,8 @@ wget https://dl.fbaipublicfiles.com/mms/tts/azj-script_latin.tar.gz # North Azer Run this command to transcribe one or more audio files: ```shell command cd /path/to/fairseq-py/ -python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code --audio "/path/to/audio_1.wav" "/path/to/audio_1.wav" +python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code \ + --audio "/path/to/audio_1.wav" "/path/to/audio_2.wav" "/path/to/audio_3.wav" ``` For more advance configuration and calculate CER/WER, you could prepare manifest folder by creating a folder with this format: diff --git a/examples/mms/asr/infer/mms_infer.py b/examples/mms/asr/infer/mms_infer.py index f9e06cd222..65f86b3148 100644 --- a/examples/mms/asr/infer/mms_infer.py +++ b/examples/mms/asr/infer/mms_infer.py @@ -21,28 +21,38 @@ def parser(): parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter") return parser.parse_args() +def reorder_decode(hypos): + outputs = [] + for hypo in hypos: + idx = int(re.findall("\(None-(\d+)\)$", hypo)[0]) + hypo = re.sub("\(\S+\)$", "", hypo).strip() + outputs.append((idx, hypo)) + outputs = sorted(outputs) + return outputs + def process(args): with tempfile.TemporaryDirectory() as tmpdir: print(">>> preparing tmp manifest dir ...", file=sys.stderr) tmpdir = Path(tmpdir) - with open(tmpdir / "dev.tsv", "w") as fw: + with open(tmpdir / "dev.tsv", "w") as fw, open(tmpdir / "dev.uid", "w") as fu: fw.write("/\n") for audio in args.audio: nsample = sf.SoundFile(audio).frames fw.write(f"{audio}\t{nsample}\n") - with open(tmpdir / "dev.uid", "w") as fw: - fw.write(f"{audio}\n"*len(args.audio)) + fu.write(f"{audio}\n") with open(tmpdir / "dev.ltr", "w") as fw: - fw.write("d u m m y | d u m m y\n"*len(args.audio)) + fw.write("d u m m y | d u m m y |\n"*len(args.audio)) with open(tmpdir / "dev.wrd", "w") as fw: fw.write("dummy dummy\n"*len(args.audio)) cmd = f""" - PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir} + PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=1440000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir} """ print(">>> loading model & running inference ...", file=sys.stderr) subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,) with open(tmpdir/"hypo.word") as fr: - for ii, hypo in enumerate(fr): + hypos = fr.readlines() + outputs = reorder_decode(hypos) + for ii, hypo in outputs: hypo = re.sub("\(\S+\)$", "", hypo).strip() print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')