
Commit

Update to have code running
William-N-Havard committed Apr 17, 2024
1 parent 52c1cab, commit 749f3fe
Showing 8 changed files with 49 additions and 37 deletions.
examples/mms/asr/config/infer_common.yaml (2 changes: 1 addition, 1 deletion)
@@ -24,7 +24,7 @@ hydra:
   run:
     dir: ${common_eval.results_path}/${dataset.gen_subset}
   sweep:
-    dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
+    dir: /tmp/${env:USER}/${env:PREFIX}/${common_eval.results_path}
     subdir: ${dataset.gen_subset}
 dataset:
   max_tokens: 2_000_000
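The only functional change in this file is the Hydra sweep directory, which now resolves under /tmp instead of /checkpoint. As a rough illustration (the user name, prefix, and results path below are made up), the interpolated path would expand like this:

# Illustrative only: manual expansion of the new hydra.sweep.dir template
# with made-up values for the interpolated variables.
user = "alice"            # ${env:USER}
prefix = "INFER"          # ${env:PREFIX}, exported by the command built in mms_infer.py
results_path = "mms_out"  # ${common_eval.results_path}

sweep_dir = f"/tmp/{user}/{prefix}/{results_path}"
print(sweep_dir)  # /tmp/alice/INFER/mms_out
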
examples/mms/asr/infer/mms_infer.py (49 changes: 25 additions, 24 deletions)
@@ -32,30 +32,31 @@ def reorder_decode(hypos):
     return outputs
 
 def process(args):
-    with tempfile.TemporaryDirectory() as tmpdir:
-        print(">>> preparing tmp manifest dir ...", file=sys.stderr)
-        tmpdir = Path(tmpdir)
-        with open(tmpdir / "dev.tsv", "w") as fw, open(tmpdir / "dev.uid", "w") as fu:
-            fw.write("/\n")
-            for audio in args.audio:
-                nsample = sf.SoundFile(audio).frames
-                fw.write(f"{audio}\t{nsample}\n")
-                fu.write(f"{audio}\n")
-        with open(tmpdir / "dev.ltr", "w") as fw:
-            fw.write("d u m m y | d u m m y |\n"*len(args.audio))
-        with open(tmpdir / "dev.wrd", "w") as fw:
-            fw.write("dummy dummy\n"*len(args.audio))
-        cmd = f"""
-        PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=1440000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir} {args.extra_infer_args}
-        """
-        print(">>> loading model & running inference ...", file=sys.stderr)
-        subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
-        with open(tmpdir/"hypo.word") as fr:
-            hypos = fr.readlines()
-        outputs = reorder_decode(hypos)
-        for ii, hypo in outputs:
-            hypo = re.sub("\(\S+\)$", "", hypo).strip()
-            print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')
+    tmpdir = Path.home().joinpath("TEMP")
+    tmpdir.mkdir(exist_ok=True)
+    print(">>> preparing tmp manifest dir ...", file=sys.stderr)
+    with open(tmpdir / f"{args.lang}:dev.tsv", "w") as fw, open(tmpdir / "dev.uid", "w") as fu:
+        fw.write("/\n")
+        for audio in args.audio:
+            nsample = sf.SoundFile(audio).frames
+            fw.write(f"{audio}\t{nsample}\n")
+            fu.write(f"{audio}\n")
+    with open(tmpdir / f"{args.lang}:dev.ltr", "w") as fw:
+        fw.write("d u m m y | d u m m y |\n"*len(args.audio))
+    with open(tmpdir / f"{args.lang}:dev.wrd", "w") as fw:
+        fw.write("dummy dummy\n"*len(args.audio))
+    cmd = f"""
+    PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=1440000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir} {args.extra_infer_args}
+    """
+    print(cmd)
+    print(">>> loading model & running inference ...", file=sys.stderr)
+    subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
+    with open(tmpdir/"hypo.word") as fr:
+        hypos = fr.readlines()
+    outputs = reorder_decode(hypos)
+    for ii, hypo in outputs:
+        hypo = re.sub("\(\S+\)$", "", hypo).strip()
+        print(f'{args.audio[ii]}\t{hypo}')
 
 
 if __name__ == "__main__":
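With this rewrite the manifests go to a persistent ~/TEMP directory instead of a throwaway temporary directory, the generated command is echoed before running, and each result is printed as a tab-separated audio-path/hypothesis pair rather than the earlier multi-line banner. A minimal sketch of consuming that new output, assuming stdout was redirected to a file named transcripts.tsv (the file name and paths are hypothetical):

# Minimal sketch (not part of the commit): collect the tab-separated lines
# that the updated mms_infer.py prints, assuming they were redirected to a
# file called transcripts.tsv.
from pathlib import Path

transcripts = {}
for line in Path("transcripts.tsv").read_text().splitlines():
    if not line.strip():
        continue
    audio_path, _, hypothesis = line.partition("\t")
    transcripts[audio_path] = hypothesis

for path, text in transcripts.items():
    print(f"{path}: {text}")
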
examples/mms/data_prep/align_and_segment.py (4 changes: 2 additions, 2 deletions)
@@ -6,8 +6,8 @@
 import argparse
 
 
-from examples.mms.data_prep.text_normalization import text_normalize
-from examples.mms.data_prep.align_utils import (
+from text_normalization import text_normalize
+from align_utils import (
     get_uroman_tokens,
     time_to_frame,
     load_model_dict,
examples/mms/data_prep/text_normalization.py (2 changes: 1 addition, 1 deletion)
@@ -2,7 +2,7 @@
 import re
 import unicodedata
 
-from examples.mms.data_prep.norm_config import norm_config
+from norm_config import norm_config
 
 
 def text_normalize(text, iso_code, lower_case=True, remove_numbers=True, remove_brackets=False):
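Both data_prep changes replace the package-qualified imports with top-level ones, so these scripts now find their helpers only when run from examples/mms/data_prep/ (or with that directory on PYTHONPATH). A small usage sketch under that assumption; the sample text and language code are illustrative only:

# Assumes the interpreter is started from examples/mms/data_prep/ so that the
# new top-level import resolves; the input text and ISO code are made up.
from text_normalization import text_normalize

cleaned = text_normalize(
    "Hello, World! 42",  # sample input
    iso_code="eng",      # illustrative language code
    lower_case=True,
    remove_numbers=True,
)
print(cleaned)
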
examples/speech_recognition/infer.py (15 changes: 9 additions, 6 deletions)
@@ -270,16 +270,19 @@ def main(args, task=None, model_state=None):
     def build_generator(args):
         w2l_decoder = getattr(args, "w2l_decoder", None)
         if w2l_decoder == "viterbi":
-            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
-
+            from w2l_decoder import W2lViterbiDecoder
+            #from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
+
             return W2lViterbiDecoder(args, task.target_dictionary)
         elif w2l_decoder == "kenlm":
-            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
-
+            from w2l_decoder import W2lKenLMDecoder
+            #from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
+
             return W2lKenLMDecoder(args, task.target_dictionary)
         elif w2l_decoder == "fairseqlm":
-            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
-
+            from w2l_decoder import W2lFairseqLMDecoder
+            #from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
+
             return W2lFairseqLMDecoder(args, task.target_dictionary)
         else:
             print(
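The decoder imports above follow the same pattern: the package-qualified path is commented out in favour of a local one, which resolves only when the script is launched from examples/speech_recognition/. A sketch (not what the commit does) of keeping both launch styles working with a fallback import:

# Sketch only: try the local import first and fall back to the package path,
# so the code runs both from examples/speech_recognition/ and from the repo root.
try:
    from w2l_decoder import W2lViterbiDecoder
except ImportError:
    from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
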
examples/speech_recognition/w2l_decoder.py (7 changes: 5 additions, 2 deletions)
@@ -18,7 +18,10 @@
 
 import numpy as np
 import torch
-from examples.speech_recognition.data.replabels import unpack_replabels
+from data.replabels import unpack_replabels
+#from examples.speech_recognition.data.replabels import unpack_replabels
+from data.replabels import unpack_replabels
+from data.replabels import unpack_replabels
 from fairseq import tasks
 from fairseq.utils import apply_to_sample
 from omegaconf import open_dict
@@ -44,7 +47,7 @@
     )
     LM = object
     LMState = object
-
+exit()
 
 class W2lDecoder(object):
     def __init__(self, args, tgt_dict):
fairseq/criterions/ctc.py (4 changes: 3 additions, 1 deletion)
@@ -229,7 +229,7 @@ def forward(self, model, sample, reduce=True, **kwargs):
 
                     c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                     c_len += len(targ_units_arr)
-
+                    print((pred_units_arr, targ_units_arr))
                     targ_words = post_process(targ_units, self.post_process).split()
 
                     pred_units = self.task.target_dictionary.string(pred_units_arr)
@@ -239,8 +239,10 @@ def forward(self, model, sample, reduce=True, **kwargs):
                         pred_words = decoded["words"]
                         w_errs += editdistance.eval(pred_words, targ_words)
                         wv_errs += editdistance.eval(pred_words_raw, targ_words)
+                        print(pred_words, targ_words)
                     else:
                         dist = editdistance.eval(pred_words_raw, targ_words)
+                        print(pred_words_raw, targ_words)
                         w_errs += dist
                         wv_errs += dist
 
fairseq/tasks/audio_pretraining.py (3 changes: 3 additions, 0 deletions)
@@ -163,10 +163,13 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
                 **mask_args,
             )
         else:
+            print(task_cfg.multi_corpus_keys)
+
             dataset_map = OrderedDict()
             self.dataset_map = {}
             multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")]
             corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
+            print(corpus_idx_map)
             data_keys = [k.split(":") for k in split.split(",")]
 
             multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")]
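The two added prints expose how the multi-corpus configuration is parsed. A worked example with made-up corpus names, mirroring the expressions in load_dataset() above:

# Illustrative values only: mirrors the parsing shown in the hunk above.
multi_corpus_keys_cfg = "corpus_a, corpus_b"  # stands in for task_cfg.multi_corpus_keys
split = "corpus_a:train,corpus_b:train"       # stands in for the split argument

multi_corpus_keys = [k.strip() for k in multi_corpus_keys_cfg.split(",")]
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
data_keys = [k.split(":") for k in split.split(",")]

print(corpus_idx_map)  # {'corpus_a': 0, 'corpus_b': 1}
print(data_keys)       # [['corpus_a', 'train'], ['corpus_b', 'train']]
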
