diff --git a/.gitignore b/.gitignore index 6606d1085..1e0667fa7 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,8 @@ htmlcov/ .coverage .coverage.* .cache +cache/ +ASR-cv* nosetests.xml coverage.xml *.cover diff --git a/README.md b/README.md index a6defc05b..fc0b33c4d 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ The SpeechBrain Benchmarks currently include the following: - [MOABB](https://github.com/speechbrain/benchmarks/tree/main/benchmarks/MOABB) - A benchmark designed for evaluating neural models in well-known EEG tasks like motor imagery, P300, and SSVEP. -- [DASB](https://github.com/speechbrain/benchmarks/tree/main/benchmarks/DASB) - A benchmark designed for evaluating discrete audio tokens across a wide range of discriminative +- [DASB](https://github.com/speechbrain/benchmarks/tree/DASB/benchmarks/DASB) - A benchmark designed for evaluating discrete audio tokens across a wide range of discriminative and generative tasks. diff --git a/benchmarks/DASB/CommonVoice/ASR/LSTM/common_voice_prepare.py b/benchmarks/DASB/CommonVoice/ASR/LSTM/common_voice_prepare.py deleted file mode 120000 index 402027afc..000000000 --- a/benchmarks/DASB/CommonVoice/ASR/LSTM/common_voice_prepare.py +++ /dev/null @@ -1 +0,0 @@ -../../common_voice_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_speech_tokenizer.yaml b/benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train.yaml similarity index 55% rename from benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_speech_tokenizer.yaml rename to benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train.yaml index eda9a2bad..a0c0600fd 100644 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_speech_tokenizer.yaml +++ b/benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train.yaml @@ -1,6 +1,8 @@ # ################################ -# Recipe for training an discrete-input ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. 
+# Script for training an ASR model evaluating an SSL representation +# model on one language from the CommonVoice dataset. A SentencePiece tokenizer +# with number of tokens equal to is learned in a first phase +# on the considered language. # # Authors # * Pooneh Mousavi 2024 @@ -9,69 +11,100 @@ # Seed needs to be set at top of yaml, before objects with parameters are made seed: 1986 __set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-LSTM/speech_tokenizer/ -output_wer_folder: !ref / +language: cy # use 'cy' for Welsh and 'eu' for Basque +output_folder: !ref results/CommonVoice/speech_tokenizer// +test_wer_file: !ref /wer_test.txt save_folder: !ref /save train_log: !ref /train_log.txt - +cached_data_folder: cache/CommonVoice//LSTM/speech_tokenizer/ +run_name: !PLACEHOLDER # Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] - -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - +data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files +dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files +test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files +accented_letters: True +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +skip_prep: False # Skip data preparation +testing: True # If set to True, the test evlaution is done, otherwise skipped. + +tokens_folder: !PLACEHOLDER # Path to the folder where extracted tokens are saved. 
+pretrain_embeddings_folder: non + +avoid_if_longer_than: 10.0 # Training parameters number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 + +batch_size_exponent: 4 # @orion_step1: --batch_size_exponent~"uniform(2, 4,discrete=True)" +batch_size: !ref 2 ** test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 -### Config for Tokenizer -vocab_size: 1024 -num_codebooks: 2 -sample_rate: 16000 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32loss_reduction: batchmean +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 +token_type: bpe # ["unigram", "bpe", "char"] +character_coverage: 1.0 -# Feature parameters +lr_model: 0.0002 # @orion_step1: --lr_model~"loguniform(0.00001,0.5)" + +# Training parameters +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. 
+max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref -encoder_dim: 1024 # Dataloader options train_dataloader_opts: batch_size: !ref +dataloader_options: + batch_size: !ref + num_workers: 4 +test_dataloader_options: + batch_size: !ref + num_workers: 4 + valid_dataloader_opts: batch_size: !ref -test_dataloader_opts: - batch_size: !ref - # Model parameters + activation: !name:torch.nn.Sigmoid dnn_layers: 1 -dnn_neurons: 1024 +dnn_neurons: 768 freeze_encoder: True # Outputs -output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 +output_neurons: 100 # BPE size, index(blank/eos/bos) = 0 # Decoding parameters blank_index: 0 @@ -92,16 +125,20 @@ test_beam_search: # If you don't want to use an LM, comment it out or set it to null kenlm_model_path: null +### Config for Tokenizer +vocab_size: 1024 +num_codebooks: 2 +sample_rate: 16000 + +# Feature parameters +encoder_dim: 1024 + # Functions and classes # epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -codec: !new:speechbrain.lobes.models.discrete.speechtokenizer_interface.SpeechTokenizer_interface - source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio - save_path: !ref +# Modules discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -111,6 +148,7 @@ attention_mlp: !new:custom_model.AttentionMLP input_dim: !ref hidden_dim: !ref + enc: !new:speechbrain.nnet.RNN.LSTM input_shape: [Null, Null, !ref ] num_layers: 2 @@ -132,17 +170,16 @@ modules: enc: !ref ctc_lin: !ref attention_mlp: !ref 
- codec: !ref discrete_embedding_layer: !ref model: !new:torch.nn.ModuleList - [!ref , !ref , !ref , !ref ] model_opt_class: !name:torch.optim.Adam - lr: !ref + lr: !ref lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref + initial_value: !ref improvement_threshold: 0.0025 annealing_factor: 0.8 patient: 0 @@ -155,7 +192,6 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer model: !ref scheduler_model: !ref attention_mlp: !ref - codec: !ref discrete_embedding_layer: !ref counter: !ref tokenizer: !ref diff --git a/benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train_dac.yaml b/benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train_dac.yaml index d80230db7..22f06450c 100644 --- a/benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train_dac.yaml +++ b/benchmarks/DASB/CommonVoice/ASR/LSTM/hparams/train_dac.yaml @@ -19,6 +19,7 @@ train_log: !ref /train_log.txt # Data files data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +cached_data_folder: !PLACEHOLDER # e.g., path/to/cache train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files @@ -28,6 +29,9 @@ valid_csv: !ref /dev.csv test_csv: !ref /test.csv skip_prep: False # Skip data preparation +tokens_folder: !PLACEHOLDER # Path to the folder where extracted tokens are saved. +pretrain_embeddings_folder: none # Optional: If pretrain_embeddings is True, this should be set to the path where the pretrained embeddings are saved. 
+ avoid_if_longer_than: 10.0 # Training parameters @@ -97,6 +101,8 @@ vocab_size: 1024 model_bitrate: 8kbps num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type sample_rate: 24000 +pretrain_embeddings: False +freeze_embedding: False # Feature parameters encoder_dim: 1024 diff --git a/benchmarks/DASB/CommonVoice/ASR/common_voice_prepare.py b/benchmarks/DASB/CommonVoice/ASR/common_voice_prepare.py new file mode 100644 index 000000000..1f17faa66 --- /dev/null +++ b/benchmarks/DASB/CommonVoice/ASR/common_voice_prepare.py @@ -0,0 +1,449 @@ +""" +Data preparation. +Download: https://commonvoice.mozilla.org/en/datasets +Author +------ +Titouan Parcollet +Luca Della Libera 2022 +Pooneh Mousavi 2022 +Salima Mdhaffar 2023 +""" + +from dataclasses import dataclass +import os +import csv +import re +import logging +import unicodedata +import functools + +from speechbrain.utils.parallel import parallel_map +from speechbrain.dataio.dataio import read_audio_info + +logger = logging.getLogger(__name__) + + +def prepare_common_voice( + data_folder, + save_folder, + train_tsv_file=None, + dev_tsv_file=None, + test_tsv_file=None, + accented_letters=False, + language="en", + skip_prep=False, +): + """ + Prepares the csv files for the Mozilla Common Voice dataset. + Download: https://commonvoice.mozilla.org/en + + Arguments + --------- + data_folder : str + Path to the folder where the original Common Voice dataset is stored. + This path should include the lang: /datasets/CommonVoice// + save_folder : str + The directory where to store the csv files. 
+ train_tsv_file : str, optional + Path to the Train Common Voice .tsv file (cs) + dev_tsv_file : str, optional + Path to the Dev Common Voice .tsv file (cs) + test_tsv_file : str, optional + Path to the Test Common Voice .tsv file (cs) + accented_letters : bool, optional + Defines if accented letters will be kept as individual letters or + transformed to the closest non-accented letters. + language: str + Specify the language for text normalization. + skip_prep: bool + If True, skip data preparation. + Example + ------- + >>> from recipes.CommonVoice.common_voice_prepare import prepare_common_voice + >>> data_folder = '/datasets/CommonVoice/en' + >>> save_folder = 'exp/CommonVoice_exp' + >>> train_tsv_file = '/datasets/CommonVoice/en/train.tsv' + >>> dev_tsv_file = '/datasets/CommonVoice/en/dev.tsv' + >>> test_tsv_file = '/datasets/CommonVoice/en/test.tsv' + >>> accented_letters = False + >>> duration_threshold = 10 + >>> prepare_common_voice( \ + data_folder, \ + save_folder, \ + train_tsv_file, \ + dev_tsv_file, \ + test_tsv_file, \ + accented_letters, \ + language="en" \ + ) + """ + + if skip_prep: + return + + # If not specified point toward standard location w.r.t CommonVoice tree + if train_tsv_file is None: + train_tsv_file = data_folder + "/train.tsv" + else: + train_tsv_file = train_tsv_file + + if dev_tsv_file is None: + dev_tsv_file = data_folder + "/dev.tsv" + else: + dev_tsv_file = dev_tsv_file + + if test_tsv_file is None: + test_tsv_file = data_folder + "/test.tsv" + else: + test_tsv_file = test_tsv_file + + # Setting the save folder + if not os.path.exists(save_folder): + os.makedirs(save_folder) + + # Setting ouput files + save_csv_train = save_folder + "/train.csv" + save_csv_dev = save_folder + "/dev.csv" + save_csv_test = save_folder + "/test.csv" + + # If csv already exists, we skip the data preparation + if skip(save_csv_train, save_csv_dev, save_csv_test): + + msg = "%s already exists, skipping data preparation!" 
% (save_csv_train) + logger.info(msg) + + msg = "%s already exists, skipping data preparation!" % (save_csv_dev) + logger.info(msg) + + msg = "%s already exists, skipping data preparation!" % (save_csv_test) + logger.info(msg) + + return + + # Additional checks to make sure the data folder contains Common Voice + check_commonvoice_folders(data_folder) + # Creating csv files for {train, dev, test} data + file_pairs = zip( + [train_tsv_file, dev_tsv_file, test_tsv_file], + [save_csv_train, save_csv_dev, save_csv_test], + ) + for tsv_file, save_csv in file_pairs: + create_csv( + tsv_file, save_csv, data_folder, accented_letters, language, + ) + + +def skip(save_csv_train, save_csv_dev, save_csv_test): + """ + Detects if the Common Voice data preparation has been already done. + If the preparation has been done, we can skip it. + Returns + ------- + bool + if True, the preparation phase can be skipped. + if False, it must be done. + """ + + # Checking folders and save options + skip = False + + if ( + os.path.isfile(save_csv_train) + and os.path.isfile(save_csv_dev) + and os.path.isfile(save_csv_test) + ): + skip = True + + return skip + + +@dataclass +class CVRow: + snt_id: str + duration: float + mp3_path: str + spk_id: str + words: str + + +def process_line(line, data_folder, language, accented_letters): + # Path is at indice 1 in Common Voice tsv files. And .mp3 files + # are located in datasets/lang/clips/ + mp3_path = data_folder + "/clips/" + line.split("\t")[1] + + file_name = mp3_path.split(".")[-2].split("/")[-1] + spk_id = line.split("\t")[0] + snt_id = file_name + + # Reading the signal (to retrieve duration in seconds) + if os.path.isfile(mp3_path): + info = read_audio_info(mp3_path) + else: + msg = "\tError loading: %s" % (str(len(file_name))) + logger.info(msg) + return None + + duration = info.num_frames / info.sample_rate + + # Getting transcript + words = line.split("\t")[3] + + # Unicode Normalization + words = unicode_normalisation(words) + + # !! 
Language specific cleaning !! + words = language_specific_preprocess(language, words) + + # Remove accents if specified + if not accented_letters: + words = strip_accents(words) + words = words.replace("'", " ") + words = words.replace("’", " ") + + # Remove multiple spaces + words = re.sub(" +", " ", words) + + # Remove spaces at the beginning and the end of the sentence + words = words.lstrip().rstrip() + + # Getting chars + chars = words.replace(" ", "_") + chars = " ".join([char for char in chars][:]) + + # Remove too short sentences (or empty): + if language in ["ja", "zh-CN"]: + if len(chars) < 3: + return None + else: + if len(words.split(" ")) < 3: + return None + + # Composition of the csv_line + return CVRow(snt_id, duration, mp3_path, spk_id, words) + + +def create_csv( + orig_tsv_file, csv_file, data_folder, accented_letters=False, language="en" +): + """ + Creates the csv file given a list of wav files. + Arguments + --------- + orig_tsv_file : str + Path to the Common Voice tsv file (standard file). + data_folder : str + Path of the CommonVoice dataset. + accented_letters : bool, optional + Defines if accented letters will be kept as individual letters or + transformed to the closest non-accented letters. + Returns + ------- + None + """ + + # Check if the given files exists + if not os.path.isfile(orig_tsv_file): + msg = "\t%s doesn't exist, verify your dataset!" % (orig_tsv_file) + logger.info(msg) + raise FileNotFoundError(msg) + + # We load and skip the header + loaded_csv = open(orig_tsv_file, "r").readlines()[1:] + nb_samples = len(loaded_csv) + + msg = "Preparing CSV files for %s samples ..." % (str(nb_samples)) + logger.info(msg) + + # Adding some Prints + msg = "Creating csv lists in %s ..." 
% (csv_file) + logger.info(msg) + + # Process and write lines + total_duration = 0.0 + + line_processor = functools.partial( + process_line, + data_folder=data_folder, + language=language, + accented_letters=accented_letters, + ) + + # Stream into a .tmp file, and rename it to the real path at the end. + csv_file_tmp = csv_file + ".tmp" + + with open(csv_file_tmp, mode="w", encoding="utf-8") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + + csv_writer.writerow(["ID", "duration", "wav", "spk_id", "wrd"]) + + for row in parallel_map(line_processor, loaded_csv): + if row is None: + continue + + total_duration += row.duration + csv_writer.writerow( + [ + row.snt_id, + str(row.duration), + row.mp3_path, + row.spk_id, + row.words, + ] + ) + + os.replace(csv_file_tmp, csv_file) + + # Final prints + msg = "%s successfully created!" % (csv_file) + logger.info(msg) + msg = "Number of samples: %s " % (str(len(loaded_csv))) + logger.info(msg) + msg = "Total duration: %s Hours" % (str(round(total_duration / 3600, 2))) + logger.info(msg) + + +def language_specific_preprocess(language, words): + # !! Language specific cleaning !! + # Important: feel free to specify the text normalization + # corresponding to your alphabet. + + if language in ["en", "fr", "it", "rw"]: + words = re.sub( + "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words + ).upper() + + if language == "de": + # this replacement helps preserve the case of ß + # (and helps retain solitary occurrences of SS) + # since python's upper() converts ß to SS. 
+ words = words.replace("ß", "0000ß0000") + words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper() + words = words.replace("'", " ") + words = words.replace("’", " ") + words = words.replace( + "0000SS0000", "ß" + ) # replace 0000SS0000 back to ß as its initial presence in the corpus + + elif language == "fr": # SM + words = re.sub( + "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words + ) + words = words.replace("’", "'") + words = words.replace("é", "é") + words = words.replace("æ", "ae") + words = words.replace("œ", "oe") + words = words.replace("â", "â") + words = words.replace("ç", "ç") + words = words.replace("è", "è") + words = words.replace("à", "à") + words = words.replace("û", "û") + words = words.replace("î", "î") + words = words.upper() + + # Case of apostrophe collés + words = words.replace("L'", "L' ") + words = words.replace("L' ", "L' ") + words = words.replace("S'", "S' ") + words = words.replace("S' ", "S' ") + words = words.replace("D'", "D' ") + words = words.replace("D' ", "D' ") + words = words.replace("J'", "J' ") + words = words.replace("J' ", "J' ") + words = words.replace("N'", "N' ") + words = words.replace("N' ", "N' ") + words = words.replace("C'", "C' ") + words = words.replace("C' ", "C' ") + words = words.replace("QU'", "QU' ") + words = words.replace("QU' ", "QU' ") + words = words.replace("M'", "M' ") + words = words.replace("M' ", "M' ") + + # Case of apostrophe qui encadre quelques mots + words = words.replace(" '", " ") + words = words.replace("A'", "A") + words = words.replace("B'", "B") + words = words.replace("E'", "E") + words = words.replace("F'", "F") + words = words.replace("G'", "G") + words = words.replace("K'", "K") + words = words.replace("Q'", "Q") + words = words.replace("V'", "V") + words = words.replace("W'", "W") + words = words.replace("Z'", "Z") + words = words.replace("O'", "O") + words = words.replace("X'", "X") + words = words.replace("AUJOURD' HUI", "AUJOURD'HUI") + elif language == "ar": + HAMZA = 
"\u0621" + ALEF_MADDA = "\u0622" + ALEF_HAMZA_ABOVE = "\u0623" + letters = ( + "ابتةثجحخدذرزژشسصضطظعغفقكلمنهويىءآأؤإئ" + + HAMZA + + ALEF_MADDA + + ALEF_HAMZA_ABOVE + ) + words = re.sub("[^" + letters + " ]+", "", words).upper() + elif language == "fa": + HAMZA = "\u0621" + ALEF_MADDA = "\u0622" + ALEF_HAMZA_ABOVE = "\u0623" + letters = ( + "ابپتةثجحخچدذرزژسشصضطظعغفقگکلمنهویىءآأؤإئ" + + HAMZA + + ALEF_MADDA + + ALEF_HAMZA_ABOVE + ) + words = re.sub("[^" + letters + " ]+", "", words).upper() + elif language == "ga-IE": + # Irish lower() is complicated, but upper() is nondeterministic, so use lowercase + def pfxuc(a): + return len(a) >= 2 and a[0] in "tn" and a[1] in "AEIOUÁÉÍÓÚ" + + def galc(w): + return w.lower() if not pfxuc(w) else w[0] + "-" + w[1:].lower() + + words = re.sub("[^-A-Za-z'ÁÉÍÓÚáéíóú]+", " ", words) + words = " ".join(map(galc, words.split(" "))) + elif language == "es": + # Fix the following error in dataset large: + # KeyError: 'The item En noviembre lanzaron Queen Elizabeth , coproducida por Foreign Noi$e . requires replacements which were not supplied.' + words = words.replace("$", "s") + return words + + +def check_commonvoice_folders(data_folder): + """ + Check if the data folder actually contains the Common Voice dataset. + If not, raises an error. + Returns + ------- + None + Raises + ------ + FileNotFoundError + If data folder doesn't contain Common Voice dataset. 
+ """ + files_str = "/clips" + # Checking clips + if not os.path.exists(data_folder + files_str): + err_msg = ( + "the folder %s does not exist (it is expected in " + "the Common Voice dataset)" % (data_folder + files_str) + ) + raise FileNotFoundError(err_msg) + + +def unicode_normalisation(text): + return str(text) + + +def strip_accents(text): + text = ( + unicodedata.normalize("NFD", text) + .encode("ascii", "ignore") + .decode("utf-8") + ) + return str(text) \ No newline at end of file diff --git a/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train.yaml b/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train.yaml new file mode 100644 index 000000000..c81c4559b --- /dev/null +++ b/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train.yaml @@ -0,0 +1,218 @@ +# ################################ +# Script for training an ASR model evaluating an SSL representation +# model on one language from the CommonVoice dataset. A SentencePiece tokenizer +# with number of tokens equal to is learned in a first phase +# on the considered language. 
+# +# Authors +# * Pooneh Mousavi 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +language: cy # use 'cy' for Welsh and 'eu' for Basque +output_folder: !ref results/CommonVoice//linear/speech_tokenizer/ +test_wer_file: !ref /wer_test.txt +save_folder: !ref /save +train_log: !ref /train_log.txt +cached_data_folder: cache/CommonVoice//linear/speech_tokenizer/ +run_name: !PLACEHOLDER +# Data files +data_folder: /data/anakuzne/cv/cv-corpus-17.0-2024-03-15/cy # e.g, /local/cv-corpus-11.0-2022-09-21/ +train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files +dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files +test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files +accented_letters: True +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +skip_prep: False # Skip data preparation +testing: True # If set to True, the test evlaution is done, otherwise skipped. + +tokens_folder: !PLACEHOLDER # Path to the folder where extracted tokens are saved. +pretrain_embeddings_folder: none # Optional: If pretrain_embeddings is True, this should be set to the path where the pretrained embeddings are saved. 
+ +avoid_if_longer_than: 10.0 + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size_exponent: 4 # @orion_step1: --batch_size_exponent~"uniform(2, 4,discrete=True)" +batch_size: !ref 2 ** +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 + +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32loss_reduction: batchmean +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.0002 # @orion_step1: --lr_model~"loguniform(0.00001,0.5)" + +# Training parameters +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. +max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +dataloader_options: + batch_size: !ref + num_workers: 4 + +test_dataloader_options: + batch_size: !ref + num_workers: 4 + +valid_dataloader_opts: + batch_size: !ref + +####################### Model parameters ########################### +# Tokenizer parameters +# These parameters should be set according to the tokenizer used to extract tokens saved in . +vocab_size: 1024 +num_codebooks: 8 +sample_rate: 16000 + +# Feature parameters +encoder_dim: 1024 +# If set to True, encoder_dim should match the dimension of the tokenizer. For Encodec, it is 128. 
+pretrain_embeddings: False +freeze_embedding: False + +# Linear +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 # @orion_step1: --dnn_layers~"uniform(1, 4,discrete=True)" +dnn_neurons: 2048 +dropout: 0.2 +output_neurons: 100 + +# BPE parameters +token_type: bpe # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +unk_index: 1 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +tokens_loader: !new:utils.tokens.TokensLoader + data_path: !ref + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.VanillaNN.VanillaNN + input_shape: [null, null, !ref ] + activation: !ref + dnn_blocks: !ref + dnn_neurons: !ref + + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + beam_size: 143 + topk: 1 + blank_index: !ref + space_token: ' ' # make sure this is the same as the one used in the tokenizer + beam_prune_logp: -12.0 + token_prune_min_logp: -1.2 + prune_history: True + alpha: 0.8 + beta: 1.2 + # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM + # It can either be a .bin or .arpa ; note: .arpa is much slower at loading + # If you don't want to use an LM, comment it out or set it to null + kenlm_model_path: null + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: 
!new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler_model: !ref + attention_mlp: !ref + discrete_embedding_layer: !ref + counter: !ref + tokenizer: !ref + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train_speech_tokenizer.yaml b/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train_speech_tokenizer.yaml index 69d471448..264bef5fd 100644 --- a/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train_speech_tokenizer.yaml +++ b/benchmarks/DASB/CommonVoice/ASR/linear/hparams/train_speech_tokenizer.yaml @@ -19,7 +19,7 @@ train_log: !ref /train_log.txt # Data files -data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +data_folder: /data/anakuzne/cv/cv-corpus-17.0-2024-03-15/cy # e.g, /local/cv-corpus-11.0-2022-09-21/ train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_encodec.py b/benchmarks/DASB/CommonVoice/ASR/train.py similarity index 51% rename from benchmarks/DASB/LibriSpeech/ASR/LSTM/train_encodec.py rename to 
benchmarks/DASB/CommonVoice/ASR/train.py index d2215ce45..134147776 100644 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_encodec.py +++ b/benchmarks/DASB/CommonVoice/ASR/train.py @@ -1,24 +1,27 @@ #!/usr/bin/env/python3 -"""Recipe for training an SSL-based ctc ASR system with librispeech. - -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. +""" Script for training an ASR model evaluating an SSL representation +model on one language from the CommonVoice dataset. A SentencePiece tokenizer +with number of tokens equal to is learned in a first phase, on +the considered language. Authors - * Adel Moumen 2024 - * Salah Zaiem 2023 - * Youcef Kemiche 2023 + * Pooneh Mousavi 2024 """ -import os import sys import torch -import torchaudio import logging import speechbrain as sb from speechbrain.utils.distributed import run_on_main, if_main_process from hyperpyyaml import load_hyperpyyaml -from pathlib import Path +import torchaudio +from speechbrain.tokenizers.SentencePiece import SentencePiece +from speechbrain.utils.data_utils import undo_padding +import time +import os + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) logger = logging.getLogger(__name__) @@ -29,51 +32,55 @@ def compute_forward(self, batch, stage): """Forward computations from the waveform batches to the output probabilities.""" batch = batch.to(self.device) wavs, wav_lens = batch.sig + p_tokens, _ = batch.speech_tokens - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens, _ = self.hparams.codec.encode(wavs, wav_lens) - embeddings = self.modules.discrete_embedding_layer(tokens) + embeddings = self.modules.discrete_embedding_layer(p_tokens) att_w = self.modules.attention_mlp(embeddings) feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = 
self.modules.enc(feats) - y = y[0] # As it is an RNN output - # Compute outputs + + if type(self.modules.encoder).__name__ == "VanillaNN": + enc_out = self.modules.encoder(feats) + + elif type(self.modules.encoder).__name__ == "LSTM": + enc_out, _ = self.modules.encoder(feats) + + else: + raise NotImplementedError + p_tokens = None - logits = self.modules.ctc_lin(y) + + # output layer for ctc log-probabilities + logits = self.modules.ctc_lin(enc_out) p_ctc = self.hparams.log_softmax(logits) + if stage == sb.Stage.VALID: p_tokens = sb.decoders.ctc_greedy_decode( p_ctc, wav_lens, blank_id=self.hparams.blank_index ) elif stage == sb.Stage.TEST: p_tokens = test_searcher(p_ctc, wav_lens) - return p_ctc, wav_lens, p_tokens def compute_objectives(self, predictions, batch, stage): """Computes the loss (CTC+NLL) given predictions and targets.""" - p_ctc, wav_lens, predicted_tokens = predictions + p_ctc, wav_lens, p_tokens = predictions ids = batch.id tokens, tokens_lens = batch.tokens loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] + # Convert token indices to words + predicted_words = self.tokenizer(p_tokens, task="decode_from_list") + elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] + predicted_words = [hyp[0].text.split(" ") for hyp in p_tokens] if stage != sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] + # Convert indices to words + target_words = undo_padding(tokens, tokens_lens) + target_words = self.tokenizer(target_words, task="decode_from_list") + self.wer_metric.append(ids, predicted_words, target_words) self.cer_metric.append(ids, predicted_words, target_words) @@ -97,18 +104,12 @@ def on_stage_end(self, stage, stage_loss, epoch): # Perform end-of-iteration things, like annealing, 
logging, etc. if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + old_lr_model, new_lr_model = self.hparams.scheduler( stage_stats["loss"] ) - # old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( - # stage_stats["loss"] - # ) sb.nnet.schedulers.update_learning_rate( self.model_optimizer, new_lr_model ) - # sb.nnet.schedulers.update_learning_rate( - # self.weights_optimizer, new_lr_weights - # ) self.hparams.train_logger.log_stats( stats_meta={"epoch": epoch, "lr_model": old_lr_model}, @@ -116,7 +117,9 @@ def on_stage_end(self, stage, stage_loss, epoch): valid_stats=stage_stats, ) self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + meta={"WER": stage_stats["WER"], "epoch": epoch}, + min_keys=["WER"], + num_to_keep=self.hparams.avg_checkpoints, ) elif stage == sb.Stage.TEST: self.hparams.train_logger.log_stats( @@ -129,45 +132,48 @@ def on_stage_end(self, stage, stage_loss, epoch): def init_optimizers(self): "Initializes the weights optimizer and model optimizer" - # self.weights_optimizer = self.hparams.weights_opt_class( - # self.hparams.attention_mlp.parameters() - # ) self.model_optimizer = self.hparams.model_opt_class( self.hparams.model.parameters() ) self.optimizers_dict = { - # "weights_optimizer": self.weights_optimizer, "model_optimizer": self.model_optimizer, } # Initializing the weights if self.checkpointer is not None: self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - # self.checkpointer.add_recoverable( - # "weights_opt", self.weights_optimizer - # ) -def dataio_prepare(hparams): +# Define custom data procedure +def dataio_prepare(hparams, tokenizer): """This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" + It also defines the data processing pipeline through user-defined functions. + """ + + # 1. 
Define datasets data_folder = hparams["data_folder"] train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + csv_path=hparams["train_csv"], + replacements={"data_root": data_folder}, ) if hparams["sorting"] == "ascending": # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") + train_data = train_data.filtered_sorted( + sort_key="duration", + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + ) # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False + hparams["dataloader_options"]["shuffle"] = False elif hparams["sorting"] == "descending": train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True + sort_key="duration", + reverse=True, + key_max_value={"duration": hparams["avoid_if_longer_than"]}, ) # when sorting do not shuffle in dataloader ! 
otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False + hparams["dataloader_options"]["shuffle"] = False elif hparams["sorting"] == "random": pass @@ -178,77 +184,69 @@ def dataio_prepare(hparams): ) valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + csv_path=hparams["valid_csv"], + replacements={"data_root": data_folder}, ) + # We also sort the validation data so it is faster to validate valid_data = valid_data.filtered_sorted(sort_key="duration") - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_csv"], + replacements={"data_root": data_folder}, + ) - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] + # We also sort the validation data so it is faster to validate + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # 1. Define tokens pipeline: + tokens_loader = hparams["tokens_loader"] + num_codebooks = hparams["num_codebooks"] + + @sb.utils.data_pipeline.takes("id") + @sb.utils.data_pipeline.provides("speech_tokens") + def tokens_pipeline(id): + tokens = tokens_loader.tokens_by_uttid(id, num_codebooks=num_codebooks) + return tokens + + sb.dataio.dataset.add_dynamic_item(datasets, tokens_pipeline) # 2. 
Define audio pipeline: @sb.utils.data_pipeline.takes("wav") @sb.utils.data_pipeline.provides("sig") def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) info = torchaudio.info(wav) + sig = sb.dataio.dataio.read_audio(wav) resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], + info.sample_rate, + hparams["sample_rate"], )(sig) - # resampled = resampled.unsqueeze(0) return resampled sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() # 3. Define text pipeline: @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) + @sb.utils.data_pipeline.provides("tokens_list", "tokens") def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) + tokens_list = tokenizer.sp.encode_as_ids(wrd) yield tokens_list tokens = torch.LongTensor(tokens_list) yield tokens sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - # 4. 
Set output: sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], + datasets, + ["id", "sig", "tokens", "speech_tokens"], ) - return train_data, valid_data, test_datasets, label_encoder + return train_data, valid_data, test_data if __name__ == "__main__": - # CLI: hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) # If distributed_launch=True then @@ -265,34 +263,73 @@ def text_pipeline(wrd): overrides=overrides, ) - if hparams["discrete_embedding_layer"].init: - hparams["discrete_embedding_layer"].init_embedding( - hparams["codec"] - .vocabulary[: hparams["num_codebooks"], :, :] - .flatten(0, 1) - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa + # Dataset preparation + from common_voice_prepare import prepare_common_voice # noqa # multi-gpu (ddp) save data preparation + # Due to DDP, we do the preparation ONLY on the main python process run_on_main( - prepare_librispeech, + prepare_common_voice, kwargs={ "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", + "save_folder": hparams["save_folder"], + "train_tsv_file": hparams["train_tsv_file"], + "dev_tsv_file": hparams["dev_tsv_file"], + "test_tsv_file": hparams["test_tsv_file"], + "accented_letters": hparams["accented_letters"], + "language": hparams["language"], "skip_prep": hparams["skip_prep"], }, ) - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams + # Defining tokenizer and loading it + tokenizer = SentencePiece( + model_dir=hparams["save_folder"], + vocab_size=hparams["output_neurons"], # Number of considered tokens + annotation_train=hparams["train_csv"], + 
annotation_read="wrd", + model_type=hparams["token_type"], + character_coverage=hparams["character_coverage"], + ) + + # Create the datasets objects as well as tokenization and encoding :-D + train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer) + + # Use pretrained embeddings + if hparams["pretrain_embeddings"]: + tokens_loader = hparams["tokens_loader"] + embs = tokens_loader.load_pretrained_embeddings( + hparams["pretain_embeddings_folder"] + ) + if isinstance(hparams["num_codebooks"], int): + embs = embs[: hparams["num_codebooks"] * hparams["vocab_size"],] + # For discrete SSL, num_codebooks is a list used to determine which layers to use. + # It is not sequential and can be, for example, [0, 1] or [1, 4]. + elif isinstance(hparams["num_codebooks"], list): + indices = [ + i + for codebook_idx in hparams["num_codebooks"] + for i in range( + codebook_idx * hparams["vocab_size"], + (codebook_idx + 1) * hparams["vocab_size"], + ) + ] + indices = torch.tensor(indices, dtype=torch.long) + embs = embs[indices] + hparams["discrete_embedding_layer"].init_embedding(embs) + + # Log number of parameters/buffers + model_params = sum( + [ + x.numel() + for module in hparams["modules"].values() + for x in module.state_dict().values() + ] + ) + hparams["train_logger"].log_stats( + stats_meta={ + "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", + }, ) # Trainer initialization @@ -303,20 +340,21 @@ def text_pipeline(wrd): checkpointer=hparams["checkpointer"], ) - # Loading the SSL model - # We dynamicaly add the tokenizer to our brain class. - asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] + # Adding objects to trainer. 
+ asr_brain.tokenizer = tokenizer + vocab_list = [ + tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size()) + ] from speechbrain.decoders.ctc import CTCBeamSearcher test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, + **hparams["test_beam_search"], + vocab_list=vocab_list, ) # Training + start_time = time.time() # Start the timer asr_brain.fit( asr_brain.hparams.epoch_counter, train_data, @@ -325,16 +363,14 @@ def text_pipeline(wrd): valid_loader_kwargs=hparams["valid_dataloader_opts"], ) - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) + end_time = time.time() # End the timer + # Calculate elapsed time + elapsed_time = end_time - start_time + logger.info(f"Model execution time: {elapsed_time:.6f} seconds") - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) + # Testing + asr_brain.evaluate( + test_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_options"], + ) diff --git a/benchmarks/DASB/CommonVoice/extraction/common_voice_prepare.py b/benchmarks/DASB/CommonVoice/extraction/common_voice_prepare.py new file mode 100644 index 000000000..5b0c7c875 --- /dev/null +++ b/benchmarks/DASB/CommonVoice/extraction/common_voice_prepare.py @@ -0,0 +1,571 @@ +""" +Data preparation. 
+Download: https://commonvoice.mozilla.org/en/datasets +Author +------ +Titouan Parcollet 2021, 2022, 2024 +Luca Della Libera 2022 +Pooneh Mousavi 2022 +Salima Mdhaffar 2023 +Adel Moumen 2024 +""" + +import csv +import functools +import os +import re +import unicodedata +from dataclasses import dataclass + +from speechbrain.dataio.dataio import read_audio_info +from speechbrain.utils.logger import get_logger +from speechbrain.utils.parallel import parallel_map + +logger = get_logger(__name__) + +VERBOSE = False +SAMPLING_RATE = 16_000 + + +def prepare_common_voice( + data_folder, + save_folder, + train_tsv_file=None, + dev_tsv_file=None, + test_tsv_file=None, + accented_letters=False, + language="en", + skip_prep=False, + convert_to_wav=False, +): + """ + Prepares the csv files for the Mozilla Common Voice dataset. + Download: https://commonvoice.mozilla.org/en + + Arguments + --------- + data_folder : str + Path to the folder where the original Common Voice dataset is stored. + This path should include the lang: /datasets/CommonVoice// + save_folder : str + The directory where to store the csv files. + train_tsv_file : str, optional + Path to the Train Common Voice .tsv file (cs) + dev_tsv_file : str, optional + Path to the Dev Common Voice .tsv file (cs) + test_tsv_file : str, optional + Path to the Test Common Voice .tsv file (cs) + accented_letters : bool, optional + Defines if accented letters will be kept as individual letters or + transformed to the closest non-accented letters. + language: str + Specify the language for text normalization. + skip_prep: bool + If True, skip data preparation. + convert_to_wav: bool + If True, `.mp3` files are converted (duplicated) to uncompressed `.wav`. + Uncompressed `wav`s can be much faster to decode than MP3, at the cost + of much higher disk usage and bandwidth. This might be useful if you are + CPU-limited in workers during training. + This invokes the `ffmpeg` commandline, so ffmpeg must be installed. 
+ + Returns + ------- + None + + Example + ------- + >>> from recipes.CommonVoice.common_voice_prepare import prepare_common_voice + >>> data_folder = '/datasets/CommonVoice/en' + >>> save_folder = 'exp/CommonVoice_exp' + >>> train_tsv_file = '/datasets/CommonVoice/en/train.tsv' + >>> dev_tsv_file = '/datasets/CommonVoice/en/dev.tsv' + >>> test_tsv_file = '/datasets/CommonVoice/en/test.tsv' + >>> accented_letters = False + >>> duration_threshold = 10 + >>> prepare_common_voice( \ + data_folder, \ + save_folder, \ + train_tsv_file, \ + dev_tsv_file, \ + test_tsv_file, \ + accented_letters, \ + language="en" \ + ) + """ + + if skip_prep: + return + + # If not specified point toward standard location w.r.t CommonVoice tree + if train_tsv_file is None: + train_tsv_file = data_folder + "/train.tsv" + else: + train_tsv_file = train_tsv_file + + if dev_tsv_file is None: + dev_tsv_file = data_folder + "/dev.tsv" + else: + dev_tsv_file = dev_tsv_file + + if test_tsv_file is None: + test_tsv_file = data_folder + "/test.tsv" + else: + test_tsv_file = test_tsv_file + + # Setting the save folder + os.makedirs(save_folder, exist_ok=True) + + # Setting output files + save_csv_train = save_folder + "/train.csv" + save_csv_dev = save_folder + "/dev.csv" + save_csv_test = save_folder + "/test.csv" + + # If csv already exists, we skip the data preparation + if skip(save_csv_train, save_csv_dev, save_csv_test): + msg = "%s already exists, skipping data preparation!" % (save_csv_train) + logger.info(msg) + + msg = "%s already exists, skipping data preparation!" % (save_csv_dev) + logger.info(msg) + + msg = "%s already exists, skipping data preparation!" 
% (save_csv_test) + logger.info(msg) + + return + + # Additional checks to make sure the data folder contains Common Voice + check_commonvoice_folders(data_folder) + # Creating csv files for {train, dev, test} data + file_pairs = zip( + [train_tsv_file, dev_tsv_file, test_tsv_file], + [save_csv_train, save_csv_dev, save_csv_test], + ) + for tsv_file, save_csv in file_pairs: + create_csv( + convert_to_wav, + tsv_file, + save_csv, + data_folder, + accented_letters, + language, + ) + + +def skip(save_csv_train, save_csv_dev, save_csv_test): + """ + Detects if the Common Voice data preparation has been already done. + If the preparation has been done, we can skip it. + + Arguments + --------- + save_csv_train : str + The train csv file + save_csv_dev : str + The dev csv file + save_csv_test : str + The test csv file + + Returns + ------- + bool + if True, the preparation phase can be skipped. + if False, it must be done. + """ + + # Checking folders and save options + skip = False + + if ( + os.path.isfile(save_csv_train) + and os.path.isfile(save_csv_dev) + and os.path.isfile(save_csv_test) + ): + skip = True + + return skip + + +@dataclass +class CVRow: + snt_id: str + duration: float + audio_path: str + spk_id: str + words: str + + +def process_line( + line, convert_to_wav, data_folder, language, accented_letters, header_map +): + """Process a line of CommonVoice tsv file. + + Arguments + --------- + line : str + A line of the CommonVoice tsv file. + convert_to_wav : bool + If True, `.mp3` files are converted (duplicated) to uncompressed `.wav`. + Uncompressed `wav`s can be much faster to decode than MP3, at the cost + of much higher disk usage and bandwidth. This might be useful if you are + CPU-limited in workers during training. + This invokes the `ffmpeg` commandline, so ffmpeg must be installed. + data_folder : str + Path to the CommonVoice dataset. + language : str + Language code, e.g. 
"en" + accented_letters : bool + Defines if accented letters will be kept as individual letters or + transformed to the closest non-accented letters. + header_map : Dict[str, int] + Map from column name to column indices + + Returns + ------- + CVRow + A dataclass containing the information about the line. + """ + + columns = line.strip().split("\t") + spk_id = columns[header_map["client_id"]] + audio_path_filename = columns[header_map["path"]] + words = columns[header_map["sentence"]] + + # Path is at indice 1 in Common Voice tsv files. And .mp3 files + # are located in datasets/lang/clips/ + audio_path = data_folder + "/clips/" + audio_path_filename + + if convert_to_wav: + audio_path = convert_mp3_to_wav(audio_path) + + file_name = audio_path.split(".")[-2].split("/")[-1] + snt_id = file_name + + # Reading the signal (to retrieve duration in seconds) + if os.path.isfile(audio_path): + info = read_audio_info(audio_path) + else: + msg = "\tError loading: %s" % (str(len(file_name))) + logger.info(msg) + return None + + duration = info.num_frames / info.sample_rate + + # Getting transcript + + # Unicode Normalization + words = unicode_normalisation(words) + + # !! Language specific cleaning !! 
+ words = language_specific_preprocess(language, words) + + # Remove accents if specified + if not accented_letters: + words = strip_accents(words) + words = words.replace("'", " ") + words = words.replace("’", " ") + + # Remove multiple spaces + words = re.sub(" +", " ", words) + + # Remove spaces at the beginning and the end of the sentence + words = words.lstrip().rstrip() + + # Getting chars + chars = words.replace(" ", "_") + chars = " ".join([char for char in chars][:]) + + # Remove too short sentences (or empty): + if language in ["ja", "zh-CN"]: + if len(chars) < 3: + return None + else: + if len(words.split(" ")) < 3: + return None + + # Composition of the csv_line + return CVRow(snt_id, duration, audio_path, spk_id, words) + + +def create_csv( + convert_to_wav, + orig_tsv_file, + csv_file, + data_folder, + accented_letters=False, + language="en", +): + """ + Creates the csv file given a list of wav files. + + Arguments + --------- + convert_to_wav : bool + If True, `.mp3` files are converted (duplicated) to uncompressed `.wav`. + Uncompressed `wav`s can be much faster to decode than MP3, at the cost + of much higher disk usage and bandwidth. This might be useful if you are + CPU-limited in workers during training. + This invokes the `ffmpeg` commandline, so ffmpeg must be installed. + orig_tsv_file : str + Path to the Common Voice tsv file (standard file). + csv_file : str + New csv file. + data_folder : str + Path of the CommonVoice dataset. + accented_letters : bool, optional + Defines if accented letters will be kept as individual letters or + transformed to the closest non-accented letters. + language : str + Language code, e.g. "en" + """ + + # Check if the given files exists + if not os.path.isfile(orig_tsv_file): + msg = "\t%s doesn't exist, verify your dataset!" 
% (orig_tsv_file) + logger.info(msg) + raise FileNotFoundError(msg) + + # We load and skip the header + csv_lines = open(orig_tsv_file, "r", encoding="utf-8").readlines() + header_line = csv_lines[0] + csv_data_lines = csv_lines[1:] + nb_samples = len(csv_data_lines) + + header_map = { + column_name: index + for index, column_name in enumerate(header_line.split("\t")) + } + + msg = "Preparing CSV files for %s samples ..." % (str(nb_samples)) + logger.info(msg) + + # Adding some Prints + msg = "Creating csv lists in %s ..." % (csv_file) + logger.info(msg) + + # Process and write lines + total_duration = 0.0 + + line_processor = functools.partial( + process_line, + convert_to_wav=convert_to_wav, + data_folder=data_folder, + language=language, + accented_letters=accented_letters, + header_map=header_map, + ) + + # Stream into a .tmp file, and rename it to the real path at the end. + csv_file_tmp = csv_file + ".tmp" + + with open(csv_file_tmp, mode="w", newline="", encoding="utf-8") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + + csv_writer.writerow(["ID", "duration", "wav", "spk_id", "wrd"]) + + for row in parallel_map(line_processor, csv_data_lines): + if row is None: + continue + + total_duration += row.duration + csv_writer.writerow( + [ + row.snt_id, + str(row.duration), + row.audio_path, + row.spk_id, + row.words, + ] + ) + + os.replace(csv_file_tmp, csv_file) + + # Final prints + msg = "%s successfully created!" % (csv_file) + logger.info(msg) + msg = "Number of samples: %s " % (str(len(csv_data_lines))) + logger.info(msg) + msg = "Total duration: %s Hours" % (str(round(total_duration / 3600, 2))) + logger.info(msg) + + +def convert_mp3_to_wav(audio_mp3_path): + """Convert an mp3 file to a wav file. + + Parameters + ---------- + audio_mp3_path : str + The path to the opus file to be converted. + + Returns + ------- + str + The path to the converted wav file. 
+ + Raises + ------ + subprocess.CalledProcessError + If the conversion process fails. + """ + audio_wav_path = audio_mp3_path.replace(".mp3", ".wav") + + if VERBOSE: + os.system( + f"ffmpeg -y -i {audio_mp3_path} -ac 1 -ar {SAMPLING_RATE} {audio_wav_path}" + ) + else: + os.system( + f"ffmpeg -y -i {audio_mp3_path} -ac 1 -ar {SAMPLING_RATE} {audio_wav_path} > /dev/null 2>&1" + ) + return audio_wav_path + + +def language_specific_preprocess(language, words): + # !! Language specific cleaning !! + # Important: feel free to specify the text normalization + # corresponding to your alphabet. + + if language in ["en", "fr", "it", "rw"]: + words = re.sub( + "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words + ).upper() + + if language == "de": + # this replacement helps preserve the case of ß + # (and helps retain solitary occurrences of SS) + # since python's upper() converts ß to SS. + words = words.replace("ß", "0000ß0000") + words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper() + words = words.replace("'", " ") + words = words.replace("’", " ") + words = words.replace( + "0000SS0000", "ß" + ) # replace 0000SS0000 back to ß as its initial presence in the corpus + + elif language == "fr": # SM + words = re.sub("[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words) + words = words.replace("’", "'") + words = words.replace("é", "é") + words = words.replace("æ", "ae") + words = words.replace("œ", "oe") + words = words.replace("â", "â") + words = words.replace("ç", "ç") + words = words.replace("è", "è") + words = words.replace("à", "à") + words = words.replace("û", "û") + words = words.replace("î", "î") + words = words.upper() + + # Case of apostrophe collés + words = words.replace("L'", "L' ") + words = words.replace("L' ", "L' ") + words = words.replace("S'", "S' ") + words = words.replace("S' ", "S' ") + words = words.replace("D'", "D' ") + words = words.replace("D' ", "D' ") + words = words.replace("J'", "J' ") + words = words.replace("J' ", "J' ") + words = 
words.replace("N'", "N' ") + words = words.replace("N' ", "N' ") + words = words.replace("C'", "C' ") + words = words.replace("C' ", "C' ") + words = words.replace("QU'", "QU' ") + words = words.replace("QU' ", "QU' ") + words = words.replace("M'", "M' ") + words = words.replace("M' ", "M' ") + + # Case of apostrophe qui encadre quelques mots + words = words.replace(" '", " ") + words = words.replace("A'", "A") + words = words.replace("B'", "B") + words = words.replace("E'", "E") + words = words.replace("F'", "F") + words = words.replace("G'", "G") + words = words.replace("K'", "K") + words = words.replace("Q'", "Q") + words = words.replace("V'", "V") + words = words.replace("W'", "W") + words = words.replace("Z'", "Z") + words = words.replace("O'", "O") + words = words.replace("X'", "X") + words = words.replace( + "AUJOURD' HUI", "AUJOURD'HUI" # cspell:disable-line + ) + elif language == "ar": + HAMZA = "\u0621" + ALEF_MADDA = "\u0622" + ALEF_HAMZA_ABOVE = "\u0623" + letters = ( + "ابتةثجحخدذرزژشسصضطظعغفقكلمنهويىءآأؤإئ" # cspell:disable-line + + HAMZA + + ALEF_MADDA + + ALEF_HAMZA_ABOVE + ) + words = re.sub("[^" + letters + " ]+", "", words).upper() + elif language == "fa": + HAMZA = "\u0621" + ALEF_MADDA = "\u0622" + ALEF_HAMZA_ABOVE = "\u0623" + letters = ( + "ابپتةثجحخچدذرزژسشصضطظعغفقگکلمنهویىءآأؤإئ" # cspell:disable-line + + HAMZA + + ALEF_MADDA + + ALEF_HAMZA_ABOVE + ) + words = re.sub("[^" + letters + " ]+", "", words).upper() + elif language == "ga-IE": + # Irish lower() is complicated, but upper() is nondeterministic, so use lowercase + def pfxuc(a): + return len(a) >= 2 and a[0] in "tn" and a[1] in "AEIOUÁÉÍÓÚ" + + def galc(w): + return w.lower() if not pfxuc(w) else w[0] + "-" + w[1:].lower() + + words = re.sub("[^-A-Za-z'ÁÉÍÓÚáéíóú]+", " ", words) + words = " ".join(map(galc, words.split(" "))) + elif language == "es": + # Fix the following error in dataset large: + # KeyError: 'The item En noviembre lanzaron Queen Elizabeth , coproducida por Foreign 
Noi$e . requires replacements which were not supplied.' + # cspell:ignore noviembre lanzaron coproducida + words = words.replace("$", "s") + return words + + +def check_commonvoice_folders(data_folder): + """ + Check if the data folder actually contains the Common Voice dataset. + If not, raises an error. + + Arguments + --------- + data_folder : str + The folder containing the data to check + + Raises + ------ + FileNotFoundError + If data folder doesn't contain Common Voice dataset. + """ + files_str = "/clips" + # Checking clips + if not os.path.exists(data_folder + files_str): + err_msg = ( + "the folder %s does not exist (it is expected in " + "the Common Voice dataset)" % (data_folder + files_str) + ) + raise FileNotFoundError(err_msg) + + +def unicode_normalisation(text): + return str(text) + + +def strip_accents(text): + text = ( + unicodedata.normalize("NFD", text) + .encode("ascii", "ignore") + .decode("utf-8") + ) + return str(text) \ No newline at end of file diff --git a/benchmarks/DASB/CommonVoice/extraction/extract.py b/benchmarks/DASB/CommonVoice/extraction/extract.py new file mode 100644 index 000000000..a4cfccfd1 --- /dev/null +++ b/benchmarks/DASB/CommonVoice/extraction/extract.py @@ -0,0 +1,107 @@ +#!/usr/bin/env/python3 +"""Recipe for extracting a discrete tokens with librispeech. 
+ +Authors + * Jarod Duret 2024 +""" + +import os +import sys +import logging +import pathlib as pl +import speechbrain as sb +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) +#sys.path.insert(0, '/data/anakuzne/benchmarks/speechbrain/speechbrain') +print(base_dir) + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing CommonVoice dataset) + from common_voice_prepare import prepare_common_voice + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_common_voice, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "train_tsv_file": hparams["train_tsv_file"], + "dev_tsv_file": hparams["dev_tsv_file"], + "test_tsv_file": hparams["test_tsv_file"], + "accented_letters": hparams["accented_letters"], + "language": hparams["language"], + "skip_prep": hparams["skip_prep"], + }, + ) + + tokens_extractor = hparams["tokens_extractor"] + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], + replacements={"data_root": data_folder}, + ) + + train_data = train_data.filtered_sorted( + sort_key="duration", + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + ) + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["dataloader_opts"]["shuffle"] = False + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["dev_csv"], + replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_csv"], + replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + datasets = [train_data, valid_data, test_data] + + merged_dataset = { + key: value + for dataset in datasets + for key, value in dataset.data.items() + } + + save_folder = pl.Path(hparams["save_folder"]) + + logger.info("Extracting dataset tokens ...") + tokens_extractor.extract_tokens( + merged_dataset, + hparams["num_codebooks"], + (save_folder / hparams["language"]).as_posix(), + ) + + if hparams["save_embedding"]: + save_folder = pl.Path(hparams["save_folder"]) + logger.info(f"Saving embeddings ...") + tokens_extractor.save_pretrained_embeddings( + (save_folder / "embeddings").as_posix(), + vocab_size=hparams["vocab_size"], + num_codebooks=hparams["num_codebooks"], + ) diff --git a/benchmarks/DASB/CommonVoice/extraction/hparams/dac.yaml b/benchmarks/DASB/CommonVoice/extraction/hparams/dac.yaml new file mode 100644 index 000000000..eb8b8eeba --- /dev/null +++ b/benchmarks/DASB/CommonVoice/extraction/hparams/dac.yaml @@ -0,0 +1,68 @@ +# ############################################################################ +# Auido Tokenizer: DAC +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/dac +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: /data/anakuzne/cv/cv-corpus-17.0-2024-03-15/cy +train_tsv: 
!ref /train.tsv +dev_tsv: !ref /dev.tsv +test_tsv: !ref /test.tsv +language: cy +accented_letters: True +skip_prep: False +convert_to_wav: False + +# We remove utterance slonger than 10s in the train/dev/test sets as +# longer sentences certainly correspond to "open microphones". +avoid_if_longer_than: 10.0 + + +batch_size: 8 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +####################### Model parameters ########################### +# Tokenizer parameters +# DAC parameters +# model_type: [16khz, 24khz, 44khz, 44khz] +# vocab_size: [1024, 1024, 1024, 1024] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +# max_num_codebooks: [12, 32, 9, 18] +# embedding_dim: [1024, 1024, 1024, 128] +model_type: 24khz +vocab_size: 1024 +model_bitrate: 8kbps +num_codebooks: 32 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +save_embedding: False + +tokenizer: !new:utils.tokenizer_interface.DACTokenizer + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/CommonVoice/extraction/hparams/discrete_ssl.yaml b/benchmarks/DASB/CommonVoice/extraction/hparams/discrete_ssl.yaml new file mode 100644 index 000000000..ae9cb618f --- /dev/null +++ b/benchmarks/DASB/CommonVoice/extraction/hparams/discrete_ssl.yaml @@ -0,0 +1,105 @@ +# ############################################################################ +# Auido Tokenizer: WavLM +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/hubert +save_folder: !ref /save +train_log: 
!ref /extraction_log.txt + +# Data files +data_folder: /data/anakuzne/cv/cv-corpus-17.0-2024-03-15/cy +train_tsv: !ref /train.tsv +dev_tsv: !ref /dev.tsv +test_tsv: !ref /test.tsv +language: cy +accented_letters: True +skip_prep: False +convert_to_wav: False + +# We remove utterance slonger than 10s in the train/dev/test sets as +# longer sentences certainly correspond to "open microphones". +avoid_if_longer_than: 10.0 + +batch_size: 1 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +### Configuration for discrete SSL model +# | SSL Model | HF Encoder | K-Means Dataset | K-Means Size | SSL Layers | Vocoder Model | +# |------------|----------------------------------------|-----------------|--------------|----------------------|---------------------------------------------| +# | WavLM | microsoft/wavlm-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wavlm-k1000-LibriTTS | +# | HuBERT | facebook/hubert-large-ll60k | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-hubert-k1000-LibriTTS | +# | Wav2Vec2 | facebook/wav2vec2-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wav2vec2-k1000-LibriTTS | + + +# ssl_model_type: hubert, wavlm, wav2vec2 +# ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large +ssl_model_type: hubert +ssl_hub: facebook/hubert-large-ll60k +ssl_folder: !ref /ssl_checkpoint +kmeans_cache_dir: !ref /kmeans_checkpoint +kmeans_dataset: LibriSpeech960 +vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS +freeze_ssl: True +freeze_feature_extractor: True +vocab_size: 1000 +save_embedding: False + +### Config for Tokenizer +# Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) +num_codebooks: [1, 3, 7, 12, 18, 23] +deduplicate: [False, False, False, False, False, False] +bpe_tokenizer_path: 
[null, null, null, null, null, null] +sample_rate: 16000 +encoder_dim: 1024 + +ssl_model: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + WavLM: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + HuBERT: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + Wav2Vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + +tokenizer: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !ref + vocoder_repo_id: !ref + kmeans_dataset: !ref + num_clusters: !ref + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/CommonVoice/extraction/hparams/encodec.yaml b/benchmarks/DASB/CommonVoice/extraction/hparams/encodec.yaml new file mode 100644 index 000000000..707a2f43f --- /dev/null +++ b/benchmarks/DASB/CommonVoice/extraction/hparams/encodec.yaml @@ -0,0 +1,70 @@ +# ############################################################################ +# Auido Tokenizer: Encodec +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/encodec +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: /data/anakuzne/cv/cv-corpus-17.0-2024-03-15/cy +train_tsv_file: !ref /train.tsv +dev_tsv_file: !ref 
/dev.tsv +test_tsv_file: !ref /test.tsv + +train_csv: !ref /train.csv +dev_csv: !ref /dev.csv +test_csv: !ref /test.csv + +language: cy +accented_letters: True +skip_prep: False +convert_to_wav: False + +# We remove utterance slonger than 10s in the train/dev/test sets as +# longer sentences certainly correspond to "open microphones". +avoid_if_longer_than: 10.0 + +batch_size: 8 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +# num_codebooks: [2, 4, 8, 16, 32] +bandwidth: 24.0 +num_codebooks: 32 +vocab_size: 1024 +sample_rate: 24000 +save_embedding: False + +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/CommonVoice/extraction/hparams/speech_tokenizer.yaml b/benchmarks/DASB/CommonVoice/extraction/hparams/speech_tokenizer.yaml new file mode 100644 index 000000000..a3dcd0bf6 --- /dev/null +++ b/benchmarks/DASB/CommonVoice/extraction/hparams/speech_tokenizer.yaml @@ -0,0 +1,60 @@ +# ############################################################################ +# Auido Tokenizer: Speech Tokenizer +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: 
!apply:torch.manual_seed [!ref ] +output_folder: !ref results/speech_tokenizer +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: /data/anakuzne/cv/cv-corpus-17.0-2024-03-15/cy +train_tsv: !ref /train.tsv +dev_tsv: !ref /dev.tsv +test_tsv: !ref /test.tsv +language: cy +accented_letters: True +skip_prep: False +convert_to_wav: False + +# We remove utterance slonger than 10s in the train/dev/test sets as +# longer sentences certainly correspond to "open microphones". +avoid_if_longer_than: 10.0 + +batch_size: 1 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +vocab_size: 1024 +num_codebooks: 8 +sample_rate: 16000 +encoder_dim: 1024 +freeze_embedding: False +save_embedding: False + +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +#tokenizer: !new:utils.tokenizer_interface.SpeechTokenizer +# source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio +# save_path: !ref + +tokenizer: !new:utils.tokenizer_interface.SpeechTokenizerWrapper + source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio + save_path: !ref + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/dac.yaml b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/dac.yaml new file mode 100644 index 000000000..605b772b5 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/dac.yaml @@ -0,0 +1,231 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Auido Tokenizer: DAC +# Encoder: LSTM Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: Pooneh Mousavi 2024 +# 
############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/dac/LSTM/ +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size: 4 # This works for 2x GPUs with 32GB +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32loss_reduction: batchmean +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.001 +weight_decay: 0.0005 + + +# Training parameters +# To make Transformers converge, the global bath size should be large enough. +# The global batch size is max_batch_len * n_gpus * gradient_accumulation. +# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. +# Please, set your parameters accordingly. +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. 
+max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +# DAC parameters +# model_type: [16khz, 24khz, 44khz, 44khz] +# vocab_size: [1024, 1024, 1024, 1024] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +# max_num_codebooks: [12, 32, 9, 18] +# embedding_dim: [1024, 1024, 1024, 128] +model_type: 24khz +vocab_size: 1024 +model_bitrate: 8kbps +num_codebooks: 2 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. 
+pretrain_embeddings: False +freeze_embedding: False + + +# LSTM +activation: !name:torch.nn.Sigmoid +dnn_layers: 2 +dnn_neurons: 1024 +dropout: 0.2 +output_neurons: 31 + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.DACTokenizer + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: !ref + bidirectional: True + dropout: !ref + hidden_size: !ref + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 +# scheduler: 
!new:speechbrain.nnet.schedulers.LinearNoamScheduler +# lr_initial: !ref +# n_warmup_steps: 7500 +# n_keep_steps: 36000 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/encodec.yaml b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/encodec.yaml new file mode 100644 index 000000000..f13e3cb53 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/encodec.yaml @@ -0,0 +1,231 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Auido Tokenizer: Encodec +# Encoder: LSTM Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: Pooneh Mousavi 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/enocdec/LSTM/ +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# 
then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size: 4 # This works for 2x GPUs with 32GB +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32 +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.001 +weight_decay: 0.0005 + + +# Training parameters +# To make Transformers converge, the global batch size should be large enough. +# The global batch size is max_batch_len * n_gpus * gradient_accumulation. +# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. +# Please, set your parameters accordingly. +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. 
+max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +# num_codebooks: [2, 4, 8, 16, 32] +vocab_size: 1024 +bandwidth: 1.5 +num_codebooks: 2 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. +pretrain_embeddings: False +freeze_embedding: False + + +# LSTM +activation: !name:torch.nn.Sigmoid +dnn_layers: 2 +dnn_neurons: 1024 +dropout: 0.2 +output_neurons: 31 + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref 
+ init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: !ref + bidirectional: True + dropout: !ref + hidden_size: !ref + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 +# scheduler: !new:speechbrain.nnet.schedulers.LinearNoamScheduler +# lr_initial: !ref +# n_warmup_steps: 7500 +# n_keep_steps: 36000 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: 
!name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/speech_tokenizer.yaml b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/speech_tokenizer.yaml new file mode 100644 index 000000000..d0e9aae5b --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/LSTM/speech_tokenizer.yaml @@ -0,0 +1,221 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Auido Tokenizer: SpeechTokenizer +# Encoder: LSTM Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: Pooneh Mousavi 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/speechtokenizer/LSTM/ +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size: 4 # This works for 2x GPUs with 32GB +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32loss_reduction: batchmean 
+valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.001 +weight_decay: 0.0005 + + +# Training parameters +# To make Transformers converge, the global batch size should be large enough. +# The global batch size is max_batch_len * n_gpus * gradient_accumulation. +# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. +# Please, set your parameters accordingly. +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. +max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +vocab_size: 1024 +num_codebooks: 2 +sample_rate: 16000 +# Feature parameters +encoder_dim: 1024 +# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. 
+pretrain_embeddings: False +freeze_embedding: False + + +# LSTM +activation: !name:torch.nn.Sigmoid +dnn_layers: 2 +dnn_neurons: 1024 +dropout: 0.2 +output_neurons: 31 + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.SpeechTokenizer + source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio + save_path: !ref + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: !ref + bidirectional: True + dropout: !ref + hidden_size: !ref + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 
0 +# scheduler: !new:speechbrain.nnet.schedulers.LinearNoamScheduler +# lr_initial: !ref +# n_warmup_steps: 7500 +# n_keep_steps: 36000 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/dac.yaml b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/dac.yaml new file mode 100644 index 000000000..8e73e3601 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/dac.yaml @@ -0,0 +1,224 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Auido Tokenizer: DAC +# Encoder: Contextnet Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: Pooneh Mousavi 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/dac/contextnet/ +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in 
/localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size: 4 # This works for 2x GPUs with 32GB +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32 +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.001 +weight_decay: 0.0005 + + +# Training parameters +# To make Transformers converge, the global batch size should be large enough. +# The global batch size is max_batch_len * n_gpus * gradient_accumulation. +# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. +# Please, set your parameters accordingly. +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. 
+max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +# DAC parameters +# model_type: [16khz, 24khz, 44khz, 44khz] +# vocab_size: [1024, 1024, 1024, 1024] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +# max_num_codebooks: [12, 32, 9, 18] +# embedding_dim: [1024, 1024, 1024, 128] +model_type: 24khz +vocab_size: 1024 +model_bitrate: 8kbps +num_codebooks: 2 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. 
+pretrain_embeddings: False +freeze_embedding: False + + +# LSTM +output_neurons: 31 + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.DACTokenizer + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.ContextNet.ContextNet + input_shape: [null, null, !ref ] + strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 640 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 +# scheduler: !new:speechbrain.nnet.schedulers.LinearNoamScheduler +# lr_initial: 
!ref +# n_warmup_steps: 7500 +# n_keep_steps: 36000 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/encodec.yaml b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/encodec.yaml new file mode 100644 index 000000000..4d88a7978 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/encodec.yaml @@ -0,0 +1,221 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Encoder: Contextnet Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: Pooneh Mousavi 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/enocdec/Contexnet/ +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# 
otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size: 4 # This works for 2x GPUs with 32GB +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32 +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.001 +weight_decay: 0.0005 + + +# Training parameters +# To make Transformers converge, the global batch size should be large enough. +# The global batch size is max_batch_len * n_gpus * gradient_accumulation. +# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. +# Please, set your parameters accordingly. +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. 
+max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +# num_codebooks: [2, 4, 8, 16, 32] +vocab_size: 1024 +bandwidth: 1.5 +num_codebooks: 2 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. +pretrain_embeddings: False +freeze_embedding: False + +output_neurons: 31 + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + 
hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.ContextNet.ContextNet + input_shape: [null, null, !ref ] + strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 640 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 +# scheduler: !new:speechbrain.nnet.schedulers.LinearNoamScheduler +# lr_initial: !ref +# n_warmup_steps: 7500 +# n_keep_steps: 36000 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git 
a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/speech_tokenizer.yaml b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/speech_tokenizer.yaml new file mode 100644 index 000000000..7fdbf8d51 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/hparams/contextnet/speech_tokenizer.yaml @@ -0,0 +1,212 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Audio Tokenizer: SpeechTokenizer +# Encoder: Contextnet Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: Pooneh Mousavi 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/speechtokenizer/contextnet/ +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size: 4 # This works for 2x GPUs with 32GB +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32 +valid_search_interval: 1 
+avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.001 +weight_decay: 0.0005 + + +# Training parameters +# To make Transformers converge, the global batch size should be large enough. +# The global batch size is max_batch_len * n_gpus * gradient_accumulation. +# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G * 2. +# Please, set your parameters accordingly. +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. +max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +vocab_size: 1024 +num_codebooks: 2 +sample_rate: 16000 +# Feature parameters +encoder_dim: 1024 +# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. 
+pretrain_embeddings: False +freeze_embedding: False + +output_neurons: 31 + +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +# SpeechTokenizer model (see https://huggingface.co/fnlp/SpeechTokenizer) +tokenizer: !new:utils.tokenizer_interface.SpeechTokenizer + source: fnlp/SpeechTokenizer # Hub ID of the pretrained SpeechTokenizer model + save_path: !ref + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.ContextNet.ContextNet + input_shape: [null, null, !ref ] + strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 640 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 +# scheduler: !new:speechbrain.nnet.schedulers.LinearNoamScheduler +#
lr_initial: !ref +# n_warmup_steps: 7500 +# n_keep_steps: 36000 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/librispeech_prepare.py b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/librispeech_prepare.py new file mode 120000 index 000000000..a3126ec94 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/librispeech_prepare.py @@ -0,0 +1 @@ +../librispeech_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_dac.py b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py similarity index 52% rename from benchmarks/DASB/LibriSpeech/ASR/LSTM/train_dac.py rename to benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py index 479d6719b..938ce8b96 100644 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_dac.py +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py @@ -10,16 +10,24 @@ import os import sys +import time import torch import torchaudio import logging import speechbrain as sb from speechbrain.utils.distributed import run_on_main, if_main_process +from speechbrain.tokenizers.SentencePiece import SentencePiece from hyperpyyaml import load_hyperpyyaml from pathlib import Path +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 
"../../")) +sys.path.append(base_dir) + + logger = logging.getLogger(__name__) +_CACHE = {"size": 0} + # Define training procedure class ASR(sb.Brain): @@ -28,24 +36,56 @@ def compute_forward(self, batch, stage): batch = batch.to(self.device) wavs, wav_lens = batch.sig - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens, _ = self.hparams.codec( - wavs.unsqueeze(1), n_quantizers=self.hparams.num_codebooks - ) - embeddings = self.modules.discrete_embedding_layer( - tokens.movedim(-2, -1) - ) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - y = y[0] # As it is an RNN output - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) + # Add waveform augmentation if specified. + if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"): + wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens) # [B, T] + + # compute features + # Extract tokens (cache them at first epoch if augmentation is disabled) + key = tuple(sorted(batch.id)) + try: + in_toks = _CACHE[key] + in_toks = in_toks.to(self.device) + except KeyError: + with torch.no_grad(): + self.hparams.tokenizer.eval().to(self.device) + in_toks = self.hparams.tokenizer.sig_to_tokens( + wavs, wav_lens, num_codebooks=hparams["num_codebooks"] + ) # [B, T, N-Q] + if stage != sb.Stage.TRAIN or ( + stage == sb.Stage.TRAIN + and (not hasattr(self.hparams, "wav_augment")) + ): + if _CACHE["size"] < self.hparams.cache_size: + _CACHE[key] = in_toks.cpu() + _CACHE["size"] += in_toks.numel() + + # Extract embeddings + in_embs = self.modules.discrete_embedding_layer( + in_toks + ) # [B, T, N-Q, D] + + # Attention-Pooling + att_w = self.modules.attention_mlp(in_embs) # [B, T, N-Q, 1] + in_embs = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze( + -2 + ) # [B, T, D] + + # forward modules + if 
type(self.modules.encoder).__name__ == "ContextNet": + enc_out = self.modules.encoder(in_embs) + + elif type(self.modules.encoder).__name__ == "LSTM": + enc_out, _ = self.modules.encoder(in_embs) + + else: + raise NotImplementedError + + # output layer for ctc log-probabilities + logits = self.modules.ctc_lin(enc_out) p_ctc = self.hparams.log_softmax(logits) + + p_tokens = None if stage == sb.Stage.VALID: p_tokens = sb.decoders.ctc_greedy_decode( p_ctc, wav_lens, blank_id=self.hparams.blank_index @@ -61,14 +101,19 @@ def compute_objectives(self, predictions, batch, stage): p_ctc, wav_lens, predicted_tokens = predictions ids = batch.id tokens, tokens_lens = batch.tokens + + # Label Augmentation + if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"): + tokens = self.hparams.wav_augment.replicate_labels(tokens) + tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens) + loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) if stage == sb.Stage.VALID: # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] + predicted_words = self.tokenizer( + predicted_tokens, task="decode_from_list" + ) elif stage == sb.Stage.TEST: predicted_words = [ hyp[0].text.split(" ") for hyp in predicted_tokens @@ -85,10 +130,10 @@ def on_stage_start(self, stage, epoch): """Gets called at the beginning of each epoch""" if stage != sb.Stage.TRAIN: self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() + self.wer_metric = self.hparams.wer_computer() def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" + """Gets called at the end of a epoch.""" # Compute/store important stats stage_stats = {"loss": stage_loss} if stage == sb.Stage.TRAIN: @@ -96,60 +141,61 @@ def on_stage_end(self, stage, stage_loss, epoch): else: stage_stats["CER"] = self.cer_metric.summarize("error_rate") 
stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + if ( + current_epoch % valid_search_interval == 0 + or stage == sb.Stage.TEST + ): + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # log stats and save checkpoint at end-of-epoch if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - # old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( - # stage_stats["loss"] - # ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - # sb.nnet.schedulers.update_learning_rate( - # self.weights_optimizer, new_lr_weights - # ) - + if type(self.hparams.scheduler).__name__ == "NewBobScheduler": + lr, new_lr = self.hparams.scheduler(stage_stats["loss"]) + sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr) + elif type(self.hparams.scheduler).__name__ == "LinearNoamScheduler": + lr = self.hparams.scheduler.current_lr + else: + raise NotImplementedError + + optimizer = self.optimizer.__class__.__name__ + epoch_stats = { + "epoch": epoch, + "lr": lr, + "optimizer": optimizer, + } self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + stats_meta=epoch_stats, train_stats=self.train_stats, valid_stats=stage_stats, ) self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + meta={"WER": stage_stats["WER"], "epoch": epoch}, + min_keys=["WER"], + num_to_keep=self.hparams.avg_checkpoints, ) + elif stage == sb.Stage.TEST: self.hparams.train_logger.log_stats( stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, test_stats=stage_stats, ) if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: + with open( + self.hparams.output_wer_folder, "w", encoding="utf-8" + 
) as w: self.wer_metric.write_stats(w) - def init_optimizers(self): - # "Initializes the weights optimizer and model optimizer" - # self.weights_optimizer = self.hparams.weights_opt_class( - # self.hparams.attention_mlp.parameters() - # ) - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - # "weights_optimizer": self.weights_optimizer, - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - # self.checkpointer.add_recoverable( - # "weights_opt", self.weights_optimizer - # ) + def on_fit_batch_end(self, batch, outputs, loss, should_step): + if ( + should_step + and type(self.hparams.scheduler).__name__ == "LinearNoamScheduler" + ): + self.hparams.scheduler(self.optimizer) -def dataio_prepare(hparams): +def dataio_prepare(hparams, tokenizer): """This function prepares the datasets to be used in the brain class. It also defines the data processing pipeline through user-defined functions.""" data_folder = hparams["data_folder"] @@ -206,11 +252,10 @@ def audio_pipeline(wav): resampled = torchaudio.transforms.Resample( info.sample_rate, hparams["sample_rate"], )(sig) - # resampled = resampled.unsqueeze(0) + # resampled = resampled.unsqueeze(0) return resampled sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() # 3. 
Define text pipeline: @sb.utils.data_pipeline.takes("wrd") @@ -221,45 +266,59 @@ def text_pipeline(wrd): yield wrd char_list = list(wrd) yield char_list - tokens_list = label_encoder.encode_sequence(char_list) + tokens_list = tokenizer.sp.encode_as_ids(wrd) yield tokens_list tokens = torch.LongTensor(tokens_list) yield tokens sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - # 4. Set output: sb.dataio.dataset.set_output_keys( datasets, ["id", "sig", "wrd", "char_list", "tokens"], ) - return train_data, valid_data, test_datasets, label_encoder + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. + train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from speechbrain.dataio.sampler import DynamicBatchSampler # noqa + + dynamic_hparams_train = hparams["dynamic_batch_sampler_train"] + dynamic_hparams_val = hparams["dynamic_batch_sampler_val"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + length_func=lambda x: x["duration"], + **dynamic_hparams_train, + ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + length_func=lambda x: x["duration"], + **dynamic_hparams_val, + ) + + return ( + train_data, + valid_data, + test_datasets, + train_batch_sampler, + valid_batch_sampler, + ) if __name__ == "__main__": # CLI: hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) # If distributed_launch=True then # create ddp_group with the right communication protocol sb.utils.distributed.ddp_init_group(run_opts) - with open(hparams_file) as fin: - 
hparams = load_hyperpyyaml(fin, overrides) - # Create experiment directory sb.create_experiment_directory( experiment_directory=hparams["output_folder"], @@ -285,25 +344,68 @@ def text_pipeline(wrd): }, ) + # Defining tokenizer and loading it + tokenizer = SentencePiece( + model_dir=hparams["save_folder"], + vocab_size=hparams["output_neurons"], + annotation_train=hparams["train_csv"], + annotation_read="wrd", + model_type=hparams["token_type"], + character_coverage=hparams["character_coverage"], + bos_id=hparams["bos_index"], + eos_id=hparams["eos_index"], + ) + # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams + ( + train_data, + valid_data, + test_datasets, + train_bsampler, + valid_bsampler, + ) = dataio_prepare(hparams, tokenizer) + + # Use pretrained embeddings + if hparams["pretrain_embeddings"]: + embs = hparams["tokenizer"].get_pretrained_embeddings( + device=run_opts["device"], + num_codebooks=hparams["num_codebooks"], + vocab_size=hparams["vocab_size"], + ) + hparams["discrete_embedding_layer"].init_embedding(embs) + + # Log number of parameters/buffers + codec_params = sum( + [x.numel() for x in hparams["tokenizer"].state_dict().values()] + ) + model_params = sum( + [ + x.numel() + for module in hparams["modules"].values() + for x in module.state_dict().values() + ] + ) + hparams["train_logger"].log_stats( + stats_meta={ + f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", + }, ) # Trainer initialization asr_brain = ASR( modules=hparams["modules"], + opt_class=hparams["model_opt_class"], hparams=hparams, run_opts=run_opts, checkpointer=hparams["checkpointer"], ) - # Loading the SSL model - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] + # Adding objects to trainer. + asr_brain.tokenizer = tokenizer + vocab_list = [ + tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size()) + ] from speechbrain.decoders.ctc import CTCBeamSearcher @@ -311,6 +413,20 @@ def text_pipeline(wrd): **hparams["test_beam_search"], vocab_list=vocab_list, ) + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + # Measure time + start_time = time.time() # Start the timer + # Training asr_brain.fit( asr_brain.hparams.epoch_counter, @@ -320,12 +436,19 @@ def text_pipeline(wrd): valid_loader_kwargs=hparams["valid_dataloader_opts"], ) + end_time = time.time() # End the timer + # Calculate elapsed time + elapsed_time = end_time - start_time + logger.info(f"Model execution time: {elapsed_time:.6f} seconds") + # hparams["train_logger"].log_stats( + # stats_meta={f"Model execution time: {elapsed_time:.6f} seconds"}, + # ) # Testing if not os.path.exists(hparams["output_wer_folder"]): os.makedirs(hparams["output_wer_folder"]) for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( + asr_brain.hparams.output_wer_folder = os.path.join( hparams["output_wer_folder"], f"wer_{k}.txt" ) asr_brain.evaluate( diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/custom_model.py b/benchmarks/DASB/LibriSpeech/ASR/LSTM/custom_model.py deleted file mode 120000 index 4b3f08ebb..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/custom_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../model/custom_model.py \ No newline at end of file diff --git 
a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_dac.yaml b/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_dac.yaml deleted file mode 100644 index 0b00db1f7..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_dac.yaml +++ /dev/null @@ -1,178 +0,0 @@ -# ################################ -# Recipe for training an discrete-input ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. -# -# Authors -# * Pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-LSTM/dac/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] - -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - - -### Config for Tokenizer -# DAC parameters -# model_type: [16khz, 24khz, 44khz, 44khz] -# vocab_size: [1024, 1024, 1024, 1024] -# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] -# max_num_codebooks: [12, 32, 9, 18] -# embedding_dim: [1024, 1024, 1024, 128] -model_type: 24khz -vocab_size: 1024 -model_bitrate: 8kbps -num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type -sample_rate: 24000 -encoder_dim: 
1024 - - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 768 -freeze_encoder: True - -# Outputs -output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -# -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# Modules -# DAC model (see https://github.com/descriptinc/descript-audio-codec) -codec: !new:speechbrain.lobes.models.discrete.dac.DAC - model_type: !ref - model_bitrate: !ref - load_pretrained: True - tag: latest - -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.nnet.RNN.LSTM - input_shape: [Null, Null, !ref ] - num_layers: 2 - bidirectional: True - dropout: 0.2 - hidden_size: 1024 - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 2048 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] 
- -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - scheduler_model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_discrete_ssl.yaml b/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_discrete_ssl.yaml deleted file mode 100644 index c5a920693..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_discrete_ssl.yaml +++ /dev/null @@ -1,216 +0,0 @@ -# ################################ -# Recipe for training an discrete-input ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. 
-# -# Authors -# * Pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-LSTM/discrete_ssl/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] - -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -### Configuration for discrete SSL model -# ssl_model_type: hubert, wavlm, wav2vec2 -# ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large -ssl_model_type: hubert # hubert, wavml or wav2vec2 -ssl_hub: facebook/hubert-large-ll60k -ssl_folder: !ref /ssl_checkpoint -kmeans_repo_id: speechbrain/SSL_Quantization -kmeans_cache_dir: !ref /kmeans_checkpoint -kmeans_dataset: LibriSpeech-100-360-500 -freeze_ssl: True -freeze_feature_extractor: True -num_clusters: 1000 - -### Config for Tokenizer -# Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) -# ssl_layer_num: [3, 7, 12, 23] -# deduplicate: [False, False, False, False] -# bpe_tokenizer_path: [null , null, null, null] -ssl_layer_num: [1, 3, 7, 12, 18, 23] -num_codebooks: 6 -deduplicate: [False, False, False, False, False, False] 
-bpe_tokenizer_path: [null, null, null, null, null, null] -sample_rate: 16000 -encoder_dim: 1024 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 1024 -freeze_encoder: True - -# Outputs -output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -# -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -tokenizer_config: - SSL_layers: !ref - deduplicates: !ref - bpe_tokenizers: !ref - -ssl_model: !apply:speechbrain.utils.hparams.choice - value: !ref - choices: - wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: !ref - output_norm: False - freeze: !ref - freeze_feature_extractor: !ref - output_all_hiddens: True - save_path: !ref - hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: !ref - output_norm: False - freeze: !ref - freeze_feature_extractor: !ref - output_all_hiddens: True - save_path: !ref - wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: !ref - output_norm: False - freeze: !ref - freeze_feature_extractor: !ref - output_all_hiddens: True - save_path: !ref - -codec: 
!new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL - save_path: !ref - ssl_model: !ref - kmeans_dataset: !ref - kmeans_repo_id: !ref - num_clusters: !ref - -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.nnet.RNN.LSTM - input_shape: [Null, Null, !ref ] - num_layers: 2 - bidirectional: True - dropout: 0.2 - hidden_size: 1024 - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 2048 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - scheduler_model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_encodec.yaml b/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_encodec.yaml deleted file mode 100644 index e2477819a..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_encodec.yaml 
+++ /dev/null @@ -1,183 +0,0 @@ -# ################################ -# Recipe for training an discrete-input ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. -# -# Authors -# * Pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-LSTM/encodec/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - - -# Data files -data_folder: data # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] - -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - - -### Config for Tokenizer -# EnCodec parameters -# sample_rate: [24000, 24000, 24000, 24000] -# vocab_size: [1024, 1024, 1024, 1024] -# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] -# num_codebooks: [2, 4, 8, 16, 32] -vocab_size: 1024 -bandwidth: 1.5 -num_codebooks: 2 -sample_rate: 24000 -# Feature parameters -encoder_dim: 1024 -# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. 
-init_embedding: False -freeze_embedding: False - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 1024 -freeze_encoder: True - -# Outputs -output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -# -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec - source: facebook/encodec_24khz # Only the 24kHz version supports mono audio - save_path: !ref - sample_rate: !ref - bandwidth: !ref - flat_embeddings: False - freeze: True - renorm_embeddings: False - -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - freeze: !ref - init: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.nnet.RNN.LSTM - input_shape: [Null, Null, !ref ] - num_layers: 2 - bidirectional: True - dropout: 0.2 - hidden_size: 1024 - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 2048 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: 
!name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - scheduler_model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_weighted_ssl.yaml b/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_weighted_ssl.yaml deleted file mode 100644 index bcfbe8d50..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/hparams/train_weighted_ssl.yaml +++ /dev/null @@ -1,162 +0,0 @@ -# ################################ -# Recipe for training an SSL-based ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. 
-# -# Authors -# * Salah Zaiem 2023 -# * Youcef Kemiche 2023 -# * Pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-LSTM/weighted_ssl/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] - -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -ssl_hub: microsoft/wavlm-large -ssl_folder: !ref /ssl_checkpoints -encoder_dim: 1024 - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -lr_weights: 0.01 -sorting: ascending -precision: fp32 -sample_rate: 16000 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 768 -freeze_encoder: True - -# Outputs -output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It 
can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -# -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -weighted_ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.WeightedSSLModel # yamllint disable-line rule:line-length - hub: !ref - save_path: !ref - -enc: !new:speechbrain.nnet.RNN.LSTM - input_shape: [Null, Null, !ref ] - num_layers: 2 - bidirectional: True - dropout: 0.2 - hidden_size: 1024 - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 2048 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - weighted_ssl_model: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -weights_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.9 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - ssl_model: !ref - scheduler_model: !ref - scheduler_encoder: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/librispeech_prepare.py 
b/benchmarks/DASB/LibriSpeech/ASR/LSTM/librispeech_prepare.py deleted file mode 120000 index cf4adfd79..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/librispeech_prepare.py +++ /dev/null @@ -1 +0,0 @@ -../../librispeech_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_discrete_ssl.py b/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_discrete_ssl.py deleted file mode 100644 index 2aac19193..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_discrete_ssl.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an discrete tokens ctc ASR system with librispeech. - -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. - -Authors - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import torchaudio -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens, _, _ = self.hparams.codec( - wavs, wav_lens, **self.hparams.tokenizer_config - ) - embeddings = self.modules.discrete_embedding_layer(tokens) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - y = y[0] # As it is an RNN output - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - if stage == sb.Stage.VALID: - p_tokens = 
sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. 
- if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - # old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( - # stage_stats["loss"] - # ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - # sb.nnet.schedulers.update_learning_rate( - # self.weights_optimizer, new_lr_weights - # ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - # "Initializes the weights optimizer and model optimizer" - # self.weights_optimizer = self.hparams.weights_opt_class( - # self.hparams.attention_mlp.parameters() - # ) - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - # "weights_optimizer": self.weights_optimizer, - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - # self.checkpointer.add_recoverable( - # "weights_opt", self.weights_optimizer - # ) - - -def dataio_prepare(hparams): - """This function prepares the datasets to be used in the brain class. 
- It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - info = torchaudio.info(wav) - resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], - )(sig) - # resampled = resampled.unsqueeze(0) - return resampled - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Loading the SSL model - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_speech_tokenizer.py b/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_speech_tokenizer.py deleted file mode 100644 index 1493b5972..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_speech_tokenizer.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an SSL-based ctc ASR system with librispeech. - -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Adel Moumen 2024 - * Salah Zaiem 2023 - * Youcef Kemiche 2023 -""" - -import os -import sys -import torch -import torchaudio -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens = self.hparams.codec(wavs).permute(1, 2, 0)[ - :, :, : self.hparams.num_codebooks - ] - embeddings = self.modules.discrete_embedding_layer(tokens) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - y = y[0] # As it is an RNN output - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" 
") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - # old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( - # stage_stats["loss"] - # ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - # sb.nnet.schedulers.update_learning_rate( - # self.weights_optimizer, new_lr_weights - # ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the weights optimizer and model optimizer" - # self.weights_optimizer = 
self.hparams.weights_opt_class( - # self.hparams.attention_mlp.parameters() - # ) - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - # "weights_optimizer": self.weights_optimizer, - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - # self.checkpointer.add_recoverable( - # "weights_opt", self.weights_optimizer - # ) - - -def dataio_prepare(hparams): - """This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! 
otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - info = torchaudio.info(wav) - resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], - )(sig) - # resampled = resampled.unsqueeze(0) - return resampled - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. 
Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here 
we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Loading the SSL model - # We dynamicaly add the tokenizer to our brain class. - asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_weighted_ssl.py b/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_weighted_ssl.py deleted file mode 100644 index 4a7aed382..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/LSTM/train_weighted_ssl.py +++ /dev/null @@ -1,322 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an SSL-based ctc ASR system with librispeech. - -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Adel Moumen 2024 - * Salah Zaiem 2023 - * Youcef Kemiche 2023 - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - feats = self.modules.weighted_ssl_model(wavs) - y = self.modules.enc(feats) - y = y[0] # As it is an RNN output - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning 
of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( - stage_stats["loss"] - ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - sb.nnet.schedulers.update_learning_rate( - self.weights_optimizer, new_lr_weights - ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the weights optimizer and model optimizer" - self.weights_optimizer = self.hparams.weights_opt_class( - [self.modules.weighted_ssl_model.weights] - ) - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - "weights_optimizer": self.weights_optimizer, - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - 
self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - self.checkpointer.add_recoverable( - "weights_opt", self.weights_optimizer - ) - - -def dataio_prepare(hparams): - """This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - return sig - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Loading the SSL model - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/custom_model.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/custom_model.py deleted file mode 120000 index 4b3f08ebb..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/custom_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../model/custom_model.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_dac.yaml b/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_dac.yaml deleted file mode 100644 index 4533e2e8d..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_dac.yaml +++ /dev/null @@ -1,172 +0,0 @@ -# ################################ -# Recipe for training an dac-based ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. 
-# -# Authors -# * pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-contextnet/dac/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -num_layers_ssl: 25 #Number of layers in the SSL model (should be 25 for large) -### Config for Tokenizer -# DAC parameters -# model_type: [16khz, 24khz, 44khz, 44khz] -# vocab_size: [1024, 1024, 1024, 1024] -# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] -# max_num_codebooks: [12, 32, 9, 18] -# embedding_dim: [1024, 1024, 1024, 128] -model_type: 24khz -vocab_size: 1024 -model_bitrate: 8kbps -num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type -sample_rate: 24000 -encoder_dim: 1024 - - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 640 -freeze_encoder: True - -# Outputs -output_neurons: 30 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - 
-test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -codec: !new:speechbrain.lobes.models.discrete.dac.DAC - model_type: !ref - model_bitrate: !ref - load_pretrained: True - tag: latest - -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.lobes.models.ContextNet.ContextNet - input_shape: [null, null, !ref ] - strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - -# only unitary strides to keep the frame rate - - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 640 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder -checkpointer: 
!new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - scheduler_model: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_discrete_ssl.yaml b/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_discrete_ssl.yaml deleted file mode 100644 index c394c73c1..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_discrete_ssl.yaml +++ /dev/null @@ -1,214 +0,0 @@ -# ################################ -# Recipe for training an discrete_ssl-based ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. -# -# Authors -# * pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-contextnet/encodec/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -num_layers_ssl: 25 #Number of layers in the SSL model (should be 25 for large) - -### Configuration for discrete SSL model -# ssl_model_type: hubert, wavlm, wav2vec2 -# ssl_hub: 
facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large -ssl_model_type: hubert # hubert, wavml or wav2vec2 -ssl_hub: facebook/hubert-large-ll60k -ssl_folder: !ref /ssl_checkpoint -kmeans_repo_id: speechbrain/SSL_Quantization -kmeans_cache_dir: !ref /kmeans_checkpoint -kmeans_dataset: LibriSpeech-100-360-500 -freeze_ssl: True -freeze_feature_extractor: True -num_clusters: 1000 - -### Config for Tokenizer -# Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) -# ssl_layer_num: [3, 7, 12, 23] -# deduplicate: [False, False, False, False] -# bpe_tokenizer_path: [null , null, null, null] -ssl_layer_num: [1, 3, 7, 12, 18, 23] -num_codebooks: 6 -deduplicate: [False, False, False, False, False, False] -bpe_tokenizer_path: [null, null, null, null, null, null] -sample_rate: 16000 -encoder_dim: 1024 - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 640 -freeze_encoder: True - -# Outputs -output_neurons: 30 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to 
null - kenlm_model_path: null - -# Functions and classes -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -tokenizer_config: - SSL_layers: !ref - deduplicates: !ref - bpe_tokenizers: !ref - -ssl_model: !apply:speechbrain.utils.hparams.choice - value: !ref - choices: - wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: !ref - output_norm: False - freeze: !ref - freeze_feature_extractor: !ref - output_all_hiddens: True - save_path: !ref - hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: !ref - output_norm: False - freeze: !ref - freeze_feature_extractor: !ref - output_all_hiddens: True - save_path: !ref - wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: !ref - output_norm: False - freeze: !ref - freeze_feature_extractor: !ref - output_all_hiddens: True - save_path: !ref - -codec: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL - save_path: !ref - ssl_model: !ref - kmeans_dataset: !ref - kmeans_repo_id: !ref - num_clusters: !ref - -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.lobes.models.ContextNet.ContextNet - input_shape: [null, null, !ref ] - strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - -# only unitary strides to keep the frame rate - - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 640 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: 
!new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - scheduler_model: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_encodec.yaml b/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_encodec.yaml deleted file mode 100644 index 6163550e9..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_encodec.yaml +++ /dev/null @@ -1,178 +0,0 @@ -# ################################ -# Recipe for training an encodec-based ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. 
-# -# Authors -# * pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-contextnet/encodec/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -num_layers_ssl: 25 #Number of layers in the SSL model (should be 25 for large) -### Config for Tokenizer -# EnCodec parameters -# sample_rate: [24000, 24000, 24000, 24000] -# vocab_size: [1024, 1024, 1024, 1024] -# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] -# num_codebooks: [2, 4, 8, 16, 32] -vocab_size: 1024 -bandwidth: 1.5 -num_codebooks: 2 -sample_rate: 24000 -# Feature parameters -encoder_dim: 1024 -# If set to True, the encoder_dim should be set to the dim of the tokenizer. For encodec it is 128. 
-init_embedding: False -freeze_embedding: False - -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 640 -freeze_encoder: True - -# Outputs -output_neurons: 30 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec - source: facebook/encodec_24khz # Only the 24kHz version supports mono audio - save_path: !ref - sample_rate: !ref - bandwidth: !ref - flat_embeddings: False - freeze: True - renorm_embeddings: False - -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - freeze: !ref - init: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.lobes.models.ContextNet.ContextNet - input_shape: [null, null, !ref ] 
- strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - -# only unitary strides to keep the frame rate - - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 640 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - scheduler_model: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_speech_tokenizer.yaml b/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_speech_tokenizer.yaml deleted file mode 100644 index aef1307ec..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_speech_tokenizer.yaml +++ /dev/null @@ -1,160 +0,0 @@ -# ################################ -# Recipe for training an speech_tokenizer-based ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. 
-# -# Authors -# * pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-contextnet/speech_tokenizer/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -num_layers_ssl: 25 #Number of layers in the SSL model (should be 25 for large) -### Config for Tokenizer -vocab_size: 1024 -num_codebooks: 2 -sample_rate: 16000 - -encoder_dim: 1024 -# Training parameters -number_of_epochs: 20 -lr: 0.0002 -sorting: ascending -precision: fp32 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 640 -freeze_encoder: True - -# Outputs -output_neurons: 30 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin 
or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) -codec: !new:speechbrain.lobes.models.discrete.speechtokenizer_interface.SpeechTokenizer_interface - source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio - save_path: !ref -discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -enc: !new:speechbrain.lobes.models.ContextNet.ContextNet - input_shape: [null, null, !ref ] - strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - -# only unitary strides to keep the frame rate - - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 640 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - attention_mlp: !ref - codec: !ref - discrete_embedding_layer: !ref - scheduler_model: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - 
-error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_weighted_ssl.yaml b/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_weighted_ssl.yaml deleted file mode 100644 index 6d806f0a5..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_weighted_ssl.yaml +++ /dev/null @@ -1,157 +0,0 @@ -# ################################ -# Recipe for training an encodec-based ctc ASR system with librispeech. -# Decoding is performed with ctc greedy or LM-rescored decoder. -# -# Authors -# * pooneh Mousavi 2024 -# ################################ - -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1986 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/MP3S-contextnet/encodec/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech -# noise/ris dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] -skip_prep: False -ckpt_interval_minutes: 25 # save checkpoint every N min -train_csv: !ref /train-clean-100.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -num_layers_ssl: 25 #Number of layers in the SSL model (should be 25 for large) -ssl_hub: microsoft/wavlm-large -ssl_folder: !ref /ssl_checkpoints -encoder_dim: 1024 - -# Training parameters -number_of_epochs: 2 -lr: 0.0002 -lr_weights: 0.01 -sorting: ascending -precision: fp32 -sample_rate: 16000 - -# With data_parallel batch_size is split into N jobs -# With DDP batch_size is multiplied by N jobs -# Must be 3 per GPU to fit 32GB of VRAM -batch_size: 4 -test_batch_size: 1 - -# Dataloader options 
-train_dataloader_opts: - batch_size: !ref - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -# Model parameters -activation: !name:torch.nn.Sigmoid -dnn_layers: 1 -dnn_neurons: 640 -freeze_encoder: True - -# Outputs -output_neurons: 30 - -# Decoding parameters -blank_index: 0 -unk_index: 1 - -test_beam_search: - beam_size: 143 - topk: 1 - blank_index: !ref - space_token: ' ' # make sure this is the same as the one used in the tokenizer - beam_prune_logp: -12.0 - token_prune_min_logp: -1.2 - prune_history: True - alpha: 0.8 - beta: 1.2 - # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM - # It can either be a .bin or .arpa ; note: .arpa is much slower at loading - # If you don't want to use an LM, comment it out or set it to null - kenlm_model_path: null - -# Functions and classes -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -weighted_ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.WeightedSSLModel # yamllint disable-line rule:line-length - hub: !ref - save_path: !ref - -enc: !new:speechbrain.lobes.models.ContextNet.ContextNet - input_shape: [null, null, !ref ] - strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - -# only unitary strides to keep the frame rate - - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: 640 - n_neurons: !ref - -log_softmax: !new:speechbrain.nnet.activations.Softmax - apply_log: True - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - -modules: - enc: !ref - ctc_lin: !ref - weighted_ssl_model: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref ] - -model_opt_class: !name:torch.optim.Adam - lr: !ref - -weights_opt_class: !name:torch.optim.Adam - lr: !ref - -lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.8 - patient: 0 - -lr_annealing_weights: 
!new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: 0.0025 - annealing_factor: 0.9 - patient: 0 - -label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - ssl_model: !ref - scheduler_model: !ref - scheduler_encoder: !ref - counter: !ref - tokenizer: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/librispeech_prepare.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/librispeech_prepare.py deleted file mode 120000 index cf4adfd79..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/librispeech_prepare.py +++ /dev/null @@ -1 +0,0 @@ -../../librispeech_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_dac.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_dac.py deleted file mode 100644 index a177e48a5..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_dac.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an discrete tokens + ctc ASR system with librispeech. -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path -import torchaudio - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens, _ = self.hparams.codec( - wavs.unsqueeze(1), n_quantizers=self.hparams.num_codebooks - ) - embeddings = self.modules.discrete_embedding_layer( - tokens.movedim(-2, -1) - ) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage 
!= sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the model optimizer" - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - - -def 
dataio_prepare(hparams): - """This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - info = torchaudio.info(wav) - resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], - )(sig) - # resampled = resampled.unsqueeze(0) - return resampled - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_discrete_ssl.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_discrete_ssl.py deleted file mode 100644 index 640f6a220..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_discrete_ssl.py +++ /dev/null @@ -1,319 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an discrete tokens + ctc ASR system with librispeech. -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path -import torchaudio - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens, _, _ = self.hparams.codec( - wavs, wav_lens, **self.hparams.tokenizer_config - ) - embeddings = self.modules.discrete_embedding_layer(tokens) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - 
target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the model optimizer" - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - - -def dataio_prepare(hparams): - 
"""This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - info = torchaudio.info(wav) - resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], - )(sig) - # resampled = resampled.unsqueeze(0) - return resampled - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_encodec.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_encodec.py deleted file mode 100644 index eb7232303..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_encodec.py +++ /dev/null @@ -1,316 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an discrete tokens + ctc ASR system with librispeech. -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path -import torchaudio - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens, _ = self.hparams.codec.encode(wavs, wav_lens) - embeddings = self.modules.discrete_embedding_layer(tokens) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, 
predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the model optimizer" - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - - -def dataio_prepare(hparams): - """This function prepares the datasets to be used in the brain class. 
- It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - info = torchaudio.info(wav) - resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], - )(sig) - # resampled = resampled.unsqueeze(0) - return resampled - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_speech_tokenizer.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_speech_tokenizer.py deleted file mode 100644 index cd784c80c..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_speech_tokenizer.py +++ /dev/null @@ -1,319 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an discrete tokens + ctc ASR system with librispeech. -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path -import torchaudio - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - # Feature extraction and attention pooling - with torch.no_grad(): - self.hparams.codec.to(self.device).eval() - tokens = self.hparams.codec(wavs).permute(1, 2, 0)[ - :, :, : self.hparams.num_codebooks - ] - embeddings = self.modules.discrete_embedding_layer(tokens) - att_w = self.modules.attention_mlp(embeddings) - feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2) - y = self.modules.enc(feats) - - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - 
target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the model optimizer" - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - - -def dataio_prepare(hparams): - 
"""This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - info = torchaudio.info(wav) - resampled = torchaudio.transforms.Resample( - info.sample_rate, hparams["sample_rate"], - )(sig) - # resampled = resampled.unsqueeze(0) - return resampled - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_weighted_ssl.py b/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_weighted_ssl.py deleted file mode 100644 index 6d053fceb..000000000 --- a/benchmarks/DASB/LibriSpeech/ASR/contextnet/train_weighted_ssl.py +++ /dev/null @@ -1,318 +0,0 @@ -#!/usr/bin/env/python3 -"""Recipe for training an SSL-based ctc ASR system with librispeech. -Decoding is performed with greedy decoding at validation time. -At test time, beamsearch is used with an optional external language model. 
- -Authors - * Pooneh Mousavi 2024 -""" - -import os -import sys -import torch -import logging -import speechbrain as sb -from speechbrain.utils.distributed import run_on_main, if_main_process -from hyperpyyaml import load_hyperpyyaml -from pathlib import Path - -logger = logging.getLogger(__name__) - - -# Define training procedure -class ASR(sb.Brain): - def compute_forward(self, batch, stage): - """Forward computations from the waveform batches to the output probabilities.""" - batch = batch.to(self.device) - wavs, wav_lens = batch.sig - - # Forward pass - feats = self.modules.weighted_ssl_model(wavs) - y = self.modules.enc(feats) - - # Compute outputs - p_tokens = None - logits = self.modules.ctc_lin(y) - p_ctc = self.hparams.log_softmax(logits) - - if stage == sb.Stage.VALID: - p_tokens = sb.decoders.ctc_greedy_decode( - p_ctc, wav_lens, blank_id=self.hparams.blank_index - ) - elif stage == sb.Stage.TEST: - p_tokens = test_searcher(p_ctc, wav_lens) - - return p_ctc, wav_lens, p_tokens - - def compute_objectives(self, predictions, batch, stage): - """Computes the loss (CTC+NLL) given predictions and targets.""" - - p_ctc, wav_lens, predicted_tokens = predictions - ids = batch.id - tokens, tokens_lens = batch.tokens - loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) - - if stage == sb.Stage.VALID: - # Decode token terms to words - predicted_words = [ - "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") - for utt_seq in predicted_tokens - ] - elif stage == sb.Stage.TEST: - predicted_words = [ - hyp[0].text.split(" ") for hyp in predicted_tokens - ] - - if stage != sb.Stage.TRAIN: - target_words = [wrd.split(" ") for wrd in batch.wrd] - self.wer_metric.append(ids, predicted_words, target_words) - self.cer_metric.append(ids, predicted_words, target_words) - - return loss - - def on_stage_start(self, stage, epoch): - """Gets called at the beginning of each epoch""" - if stage != sb.Stage.TRAIN: - self.cer_metric = self.hparams.cer_computer() - 
self.wer_metric = self.hparams.error_rate_computer() - - def on_stage_end(self, stage, stage_loss, epoch): - """Gets called at the end of an epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["CER"] = self.cer_metric.summarize("error_rate") - stage_stats["WER"] = self.wer_metric.summarize("error_rate") - - # Perform end-of-iteration things, like annealing, logging, etc. - if stage == sb.Stage.VALID: - old_lr_model, new_lr_model = self.hparams.lr_annealing_model( - stage_stats["loss"] - ) - old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( - stage_stats["loss"] - ) - sb.nnet.schedulers.update_learning_rate( - self.model_optimizer, new_lr_model - ) - sb.nnet.schedulers.update_learning_rate( - self.weights_optimizer, new_lr_weights - ) - - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr_model": old_lr_model}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"WER": stage_stats["WER"]}, min_keys=["WER"], - ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - with open(self.hparams.test_wer_file, "w") as w: - self.wer_metric.write_stats(w) - - def init_optimizers(self): - "Initializes the weights optimizer and model optimizer" - self.weights_optimizer = self.hparams.weights_opt_class( - [self.modules.weighted_ssl_model.weights] - ) - self.model_optimizer = self.hparams.model_opt_class( - self.hparams.model.parameters() - ) - self.optimizers_dict = { - "weights_optimizer": self.weights_optimizer, - "model_optimizer": self.model_optimizer, - } - # Initializing the weights - if self.checkpointer is not None: - self.checkpointer.add_recoverable("modelopt", self.model_optimizer) - self.checkpointer.add_recoverable( - 
"weights_opt", self.weights_optimizer - ) - - -def dataio_prepare(hparams): - """This function prepares the datasets to be used in the brain class. - It also defines the data processing pipeline through user-defined functions.""" - data_folder = hparams["data_folder"] - - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, - ) - - if hparams["sorting"] == "ascending": - # we sort training data to speed up training and get better results. - train_data = train_data.filtered_sorted(sort_key="duration") - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "descending": - train_data = train_data.filtered_sorted( - sort_key="duration", reverse=True - ) - # when sorting do not shuffle in dataloader ! otherwise is pointless - hparams["train_dataloader_opts"]["shuffle"] = False - - elif hparams["sorting"] == "random": - pass - - else: - raise NotImplementedError( - "sorting must be random, ascending or descending" - ) - - valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, - ) - valid_data = valid_data.filtered_sorted(sort_key="duration") - - # test is separate - test_datasets = {} - for csv_file in hparams["test_csv"]: - name = Path(csv_file).stem - test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( - csv_path=csv_file, replacements={"data_root": data_folder} - ) - test_datasets[name] = test_datasets[name].filtered_sorted( - sort_key="duration" - ) - - datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] - - # 2. 
Define audio pipeline: - @sb.utils.data_pipeline.takes("wav") - @sb.utils.data_pipeline.provides("sig") - def audio_pipeline(wav): - sig = sb.dataio.dataio.read_audio(wav) - return sig - - sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) - label_encoder = sb.dataio.encoder.CTCTextEncoder() - - # 3. Define text pipeline: - @sb.utils.data_pipeline.takes("wrd") - @sb.utils.data_pipeline.provides( - "wrd", "char_list", "tokens_list", "tokens" - ) - def text_pipeline(wrd): - yield wrd - char_list = list(wrd) - yield char_list - tokens_list = label_encoder.encode_sequence(char_list) - yield tokens_list - tokens = torch.LongTensor(tokens_list) - yield tokens - - sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) - - lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") - special_labels = { - "blank_label": hparams["blank_index"], - "unk_label": hparams["unk_index"], - } - label_encoder.load_or_create( - path=lab_enc_file, - from_didatasets=[train_data], - output_key="char_list", - special_labels=special_labels, - sequence_input=True, - ) - - # 4. 
Set output: - sb.dataio.dataset.set_output_keys( - datasets, ["id", "sig", "wrd", "char_list", "tokens"], - ) - return train_data, valid_data, test_datasets, label_encoder - - -if __name__ == "__main__": - - # CLI: - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - - # If distributed_launch=True then - # create ddp_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset prep (parsing Librispeech) - from librispeech_prepare import prepare_librispeech # noqa - - # multi-gpu (ddp) save data preparation - run_on_main( - prepare_librispeech, - kwargs={ - "data_folder": hparams["data_folder"], - "tr_splits": hparams["train_splits"], - "dev_splits": hparams["dev_splits"], - "te_splits": hparams["test_splits"], - "save_folder": hparams["output_folder"], - "merge_lst": hparams["train_splits"], - "merge_name": "train.csv", - "skip_prep": hparams["skip_prep"], - }, - ) - - # here we create the datasets objects as well as tokenization and encoding - train_data, valid_data, test_datasets, label_encoder = dataio_prepare( - hparams - ) - - # Trainer initialization - asr_brain = ASR( - modules=hparams["modules"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # We dynamicaly add the tokenizer to our brain class. 
- asr_brain.tokenizer = label_encoder - - ind2lab = label_encoder.ind2lab - vocab_list = [ind2lab[x] for x in range(len(ind2lab))] - - from speechbrain.decoders.ctc import CTCBeamSearcher - - test_searcher = CTCBeamSearcher( - **hparams["test_beam_search"], vocab_list=vocab_list, - ) - - # Training - asr_brain.fit( - asr_brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_opts"], - valid_loader_kwargs=hparams["valid_dataloader_opts"], - ) - - # Testing - if not os.path.exists(hparams["output_wer_folder"]): - os.makedirs(hparams["output_wer_folder"]) - - for k in test_datasets.keys(): # keys are test_clean, test_other etc - asr_brain.hparams.test_wer_file = os.path.join( - hparams["output_wer_folder"], f"wer_{k}.txt" - ) - asr_brain.evaluate( - test_datasets[k], - test_loader_kwargs=hparams["test_dataloader_opts"], - min_key="WER", - ) diff --git a/benchmarks/DASB/LibriSpeech/ASR/hparams/LSTM/train.yaml b/benchmarks/DASB/LibriSpeech/ASR/hparams/LSTM/train.yaml new file mode 100644 index 000000000..8b9581dc9 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR/hparams/LSTM/train.yaml @@ -0,0 +1,220 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# Audio Tokenizer: Encodec +# Encoder: LSTM Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: +# - Pooneh Mousavi 2024 +# - Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +run_name: !PLACEHOLDER +output_folder: !ref results/LSTM// +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt +testing: True # If set to True, the test evaluation is done, otherwise skipped.
+ +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +cached_data_folder: !PLACEHOLDER # e.g., path/to/cache +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100"] #["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + +tokens_folder: !PLACEHOLDER # Path to the folder where extracted tokens are saved. +pretrain_embeddings_folder: none # Optional: If pretrain_embeddings is True, this should be set to the path where the pretrained embeddings are saved. + +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size_exponent: 4 # @orion_step1: --batch_size_exponent~"uniform(2, 4,discrete=True)" +batch_size: !ref 2 ** +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32 +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.0001 # @orion_step1: --lr_model~"loguniform(0.00001,0.5)" +weight_decay: 0.0005 + + +# Training parameters +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples.
+max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +# These parameters should be set according to the tokenizer used to extract tokens saved in . +vocab_size: 1024 +num_codebooks: 2 +sample_rate: 24000 + +# Feature parameters +encoder_dim: 1024 +# If set to True, encoder_dim should match the dimension of the tokenizer. For Encodec, it is 128. +pretrain_embeddings: False +freeze_embedding: False + +# LSTM +activation: !name:torch.nn.Sigmoid +dnn_layers: 2 # @orion_step1: --dnn_layers~"uniform(1, 4,discrete=True)" +dnn_neurons: 1024 +dropout: 0.2 +output_neurons: 31 + +# BPE parameters +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +tokens_loader: !new:utils.tokens.TokensLoader + data_path: !ref + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: !ref + bidirectional: True + dropout: !ref + hidden_size: !ref + +ctc_lin: 
!new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + # tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR/hparams/contextnet/train.yaml b/benchmarks/DASB/LibriSpeech/ASR/hparams/contextnet/train.yaml new file mode 100644 index 000000000..eab197c68 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR/hparams/contextnet/train.yaml @@ -0,0 +1,214 @@ +# ############################################################################ +# Model: E2E ASR with CTC +# 
Audio Tokenizer: Encodec +# Encoder: ContextNet Encoder +# Decoder: CTC beam searcher and greedy searcher +# Tokens: character +# Training: Librispeech 960h +# Authors: +# - Pooneh Mousavi 2024 +# - Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +run_name: !PLACEHOLDER +output_folder: !ref results/LSTM// +output_wer_folder: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt +testing: True # If set to True, the test evaluation is done, otherwise skipped. + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +cached_data_folder: !PLACEHOLDER # e.g., path/to/cache +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100"] #["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + +tokens_folder: !PLACEHOLDER # Path to the folder where extracted tokens are saved. +pretrain_embeddings_folder: none # Optional: If pretrain_embeddings is True, this should be set to the path where the pretrained embeddings are saved.
+ +####################### Training Parameters #################################### +number_of_epochs: 20 +batch_size_exponent: 4 # @orion_step1: --batch_size_exponent~"uniform(2, 4,discrete=True)" +batch_size: !ref 2 ** +test_batch_size: 1 +grad_accumulation_factor: 2 +max_grad_norm: 5.0 +sorting: descending #random +num_workers: 8 +loss_reduction: batchmean +precision: fp32 # bf16, fp16 or fp32 +valid_search_interval: 1 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation +cache_size: 1.e+10 + +lr_model: 0.0001 # @orion_step1: --lr_model~"loguniform(0.00001,0.5)" +weight_decay: 0.0005 + + +# Training parameters +dynamic_batching: True +max_batch_length_train: 850 +max_batch_len_val: 100 +num_bucket: 200 +shuffle: False # if true re-creates batches at each epoch shuffling examples. +max_batch_ex: 128 +batch_ordering: random + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_val: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +####################### Model parameters ########################### +# Tokenizer parameters +# These parameters should be set according to the tokenizer used to extract tokens saved in . +vocab_size: 1024 +num_codebooks: 2 +sample_rate: 24000 + +# Feature parameters +encoder_dim: 1024 +# If set to True, encoder_dim should match the dimension of the tokenizer. For Encodec, it is 128.
+pretrain_embeddings: False +freeze_embedding: False + +# Contextnet + +output_neurons: 31 + +# BPE parameters +# BPE parameters +token_type: char # ["unigram", "bpe", "char"] +character_coverage: 1.0 +blank_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +beam_size: 100 +beam_prune_logp: -12.0 +token_prune_min_logp: -1.2 +prune_history: False + +############################## models ################################ +tokens_loader: !new:utils.tokens.TokensLoader + data_path: !ref + +discrete_embedding_layer: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + # hidden_dim: !ref + freeze: !ref + init: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.ContextNet.ContextNet + input_shape: [null, null, !ref ] + strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 640 + n_neurons: !ref + +modules: + encoder: !ref + ctc_lin: !ref + attention_mlp: !ref + # tokenizer: !ref + discrete_embedding_layer: !ref + + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### +# Decoding parameters +test_beam_search: + blank_index: !ref + beam_size: !ref + beam_prune_logp: !ref + token_prune_min_logp: !ref + prune_history: !ref + alpha: 0.8 + beta: 1.2 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +model_opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + weight_decay: !ref + +############################## Logging and Pretrainer ########################## +checkpointer: 
!new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + + +# Functions and classes +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True +wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats diff --git a/benchmarks/DASB/LibriSpeech/ASR/librispeech_prepare.py b/benchmarks/DASB/LibriSpeech/ASR/librispeech_prepare.py new file mode 120000 index 000000000..a3126ec94 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR/librispeech_prepare.py @@ -0,0 +1 @@ +../librispeech_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriSpeech/ASR/train.py b/benchmarks/DASB/LibriSpeech/ASR/train.py new file mode 100644 index 000000000..ec6ac1b42 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/ASR/train.py @@ -0,0 +1,457 @@ +#!/usr/bin/env/python3 +"""Recipe for training an discrete tokens ctc ASR system with librispeech. + +Decoding is performed with greedy decoding at validation time. +At test time, beamsearch is used with an optional external language model. 
+ +Authors + * Pooneh Mousavi 2024 + * Jarod Duret 2024 +""" + +import os +import sys +import time +import torch +import torchaudio +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main, if_main_process +from speechbrain.tokenizers.SentencePiece import SentencePiece +from hyperpyyaml import load_hyperpyyaml +from pathlib import Path + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) + + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + in_toks, _ = batch.speech_tokens + + in_embs = self.modules.discrete_embedding_layer( + in_toks + ) # [B, T, N-Q, D] + + # Attention-Pooling + att_w = self.modules.attention_mlp(in_embs) # [B, T, N-Q, 1] + in_embs = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze( + -2 + ) # [B, T, D] + + # forward modules + if type(self.modules.encoder).__name__ == "ContextNet": + enc_out = self.modules.encoder(in_embs) + + elif type(self.modules.encoder).__name__ == "LSTM": + enc_out, _ = self.modules.encoder(in_embs) + + else: + raise NotImplementedError + + # output layer for ctc log-probabilities + logits = self.modules.ctc_lin(enc_out) + p_ctc = self.hparams.log_softmax(logits) + + p_tokens = None + if stage == sb.Stage.VALID: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + elif stage == sb.Stage.TEST: + p_tokens = test_searcher(p_ctc, wav_lens) + + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + + # Label Augmentation + if stage == 
sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"): + tokens = self.hparams.wav_augment.replicate_labels(tokens) + tokens_lens = self.hparams.wav_augment.replicate_labels(tokens_lens) + + loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + + if stage == sb.Stage.VALID: + # Decode token terms to words + predicted_words = self.tokenizer( + predicted_tokens, task="decode_from_list" + ) + elif stage == sb.Stage.TEST: + predicted_words = [ + hyp[0].text.split(" ") for hyp in predicted_tokens + ] + + if stage != sb.Stage.TRAIN: + target_words = [wrd.split(" ") for wrd in batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.wer_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of a epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + if ( + current_epoch % valid_search_interval == 0 + or stage == sb.Stage.TEST + ): + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # log stats and save checkpoint at end-of-epoch + if stage == sb.Stage.VALID: + if type(self.hparams.scheduler).__name__ == "NewBobScheduler": + lr, new_lr = self.hparams.scheduler(stage_stats["loss"]) + sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr) + elif type(self.hparams.scheduler).__name__ == "LinearNoamScheduler": + lr = self.hparams.scheduler.current_lr + else: + raise 
NotImplementedError + + optimizer = self.optimizer.__class__.__name__ + epoch_stats = { + "epoch": epoch, + "lr": lr, + "optimizer": optimizer, + } + self.hparams.train_logger.log_stats( + stats_meta=epoch_stats, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"], "epoch": epoch}, + min_keys=["WER"], + num_to_keep=self.hparams.avg_checkpoints, + ) + + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + if if_main_process(): + with open( + self.hparams.output_wer_folder, "w", encoding="utf-8" + ) as w: + self.wer_metric.write_stats(w) + + def on_fit_batch_end(self, batch, outputs, loss, should_step): + if ( + should_step + and type(self.hparams.scheduler).__name__ == "LinearNoamScheduler" + ): + self.hparams.scheduler(self.optimizer) + + +def dataio_prepare(hparams, tokenizer): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions. + """ + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + # test is separate + test_datasets = {} + for csv_file in hparams["test_csv"]: + name = Path(csv_file).stem + test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_file, replacements={"data_root": data_folder} + ) + test_datasets[name] = test_datasets[name].filtered_sorted( + sort_key="duration" + ) + + datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] + + # 1. Define tokens pipeline: + tokens_loader = hparams["tokens_loader"] + num_codebooks = hparams["num_codebooks"] + + @sb.utils.data_pipeline.takes("id") + @sb.utils.data_pipeline.provides("speech_tokens") + def tokens_pipeline(id): + tokens = tokens_loader.tokens_by_uttid(id, num_codebooks=num_codebooks) + return tokens + + sb.dataio.dataset.add_dynamic_item(datasets, tokens_pipeline) + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + info = torchaudio.info(wav) + resampled = torchaudio.transforms.Resample( + info.sample_rate, hparams["sample_rate"], + )(sig) + # resampled = resampled.unsqueeze(0) + return resampled + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. 
Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides( + "wrd", "char_list", "tokens_list", "tokens" + ) + def text_pipeline(wrd): + yield wrd + char_list = list(wrd) + yield char_list + tokens_list = tokenizer.sp.encode_as_ids(wrd) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "char_list", "tokens", "speech_tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. + train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from speechbrain.dataio.sampler import DynamicBatchSampler # noqa + + dynamic_hparams_train = hparams["dynamic_batch_sampler_train"] + dynamic_hparams_val = hparams["dynamic_batch_sampler_val"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + length_func=lambda x: x["duration"], + **dynamic_hparams_train, + ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + length_func=lambda x: x["duration"], + **dynamic_hparams_val, + ) + + return ( + train_data, + valid_data, + test_datasets, + train_batch_sampler, + valid_batch_sampler, + ) + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from librispeech_prepare import prepare_librispeech # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_librispeech, + kwargs={ + 
"data_folder": hparams["data_folder"], + "tr_splits": hparams["train_splits"], + "dev_splits": hparams["dev_splits"], + "te_splits": hparams["test_splits"], + "save_folder": hparams["cached_data_folder"], + "merge_lst": hparams["train_splits"], + "merge_name": "train.csv", + "skip_prep": hparams["skip_prep"], + }, + ) + + # Defining tokenizer and loading it + tokenizer = SentencePiece( + model_dir=hparams["cached_data_folder"], + vocab_size=hparams["output_neurons"], + annotation_train=hparams["train_csv"], + annotation_read="wrd", + model_type=hparams["token_type"], + character_coverage=hparams["character_coverage"], + bos_id=hparams["bos_index"], + eos_id=hparams["eos_index"], + ) + + # here we create the datasets objects as well as tokenization and encoding + ( + train_data, + valid_data, + test_datasets, + train_bsampler, + valid_bsampler, + ) = dataio_prepare(hparams, tokenizer) + + # Use pretrained embeddings + if hparams["pretrain_embeddings"]: + tokens_loader = hparams["tokens_loader"] + embs = tokens_loader.load_pretrained_embeddings( + hparams["pretain_embeddings_folder"] + ) + if isinstance(hparams["num_codebooks"], int): + embs = embs[ + : hparams["num_codebooks"] * hparams["vocab_size"], + ] + # For discrete SSL, num_codebooks is a list used to determine which layers to use. + # It is not sequential and can be, for example, [0, 1] or [1, 4]. 
+ elif isinstance(hparams["num_codebooks"], list): + indices = [ + i + for codebook_idx in hparams["num_codebooks"] + for i in range( + codebook_idx * hparams["vocab_size"], + (codebook_idx + 1) * hparams["vocab_size"], + ) + ] + indices = torch.tensor(indices, dtype=torch.long) + embs = embs[indices] + hparams["discrete_embedding_layer"].init_embedding(embs) + + # Log number of parameters/buffers + model_params = sum( + [ + x.numel() + for module in hparams["modules"].values() + for x in module.state_dict().values() + ] + ) + hparams["train_logger"].log_stats( + stats_meta={ + "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", + }, + ) + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + opt_class=hparams["model_opt_class"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Adding objects to trainer. + asr_brain.tokenizer = tokenizer + vocab_list = [ + tokenizer.sp.id_to_piece(i) for i in range(tokenizer.sp.vocab_size()) + ] + + from speechbrain.decoders.ctc import CTCBeamSearcher + + test_searcher = CTCBeamSearcher( + **hparams["test_beam_search"], vocab_list=vocab_list, + ) + + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + # Measure time + start_time = time.time() # Start the timer + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + end_time = time.time() # End the timer + # Calculate elapsed time + elapsed_time = end_time - start_time + logger.info(f"Model execution time: {elapsed_time:.6f} seconds") + + if 
hparams["testing"]: + # Testing + if not os.path.exists(hparams["output_wer_folder"]): + os.makedirs(hparams["output_wer_folder"]) + + for k in test_datasets.keys(): # keys are test_clean, test_other etc + asr_brain.hparams.output_wer_folder = os.path.join( + hparams["output_wer_folder"], f"wer_{k}.txt" + ) + asr_brain.evaluate( + test_datasets[k], + test_loader_kwargs=hparams["test_dataloader_opts"], + min_key="WER", + ) diff --git a/benchmarks/DASB/LibriSpeech/extraction/extract.py b/benchmarks/DASB/LibriSpeech/extraction/extract.py new file mode 100644 index 000000000..3979ba731 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/extraction/extract.py @@ -0,0 +1,96 @@ +#!/usr/bin/env/python3 +"""Recipe for extracting a discrete tokens with librispeech. + +Authors + * Jarod Duret 2024 +""" + +import os +import sys +import logging +import pathlib as pl +import speechbrain as sb +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) + +print(base_dir) + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from librispeech_prepare import prepare_librispeech # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_librispeech, + kwargs={ + "data_folder": hparams["data_folder"], + "tr_splits": hparams["train_splits"], + "dev_splits": hparams["dev_splits"], + "te_splits": hparams["test_splits"], + "save_folder": hparams["output_folder"], + "merge_lst": 
hparams["train_splits"], + "merge_name": "train.csv", + "skip_prep": hparams["skip_prep"], + }, + ) + + tokens_extractor = hparams["tokens_extractor"] + data_folder = hparams["data_folder"] + datasets = [] + for split in ["train", "valid"]: + csv_path = hparams[f"{split}_csv"] + name = pl.Path(csv_path).stem + dataset = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_path, replacements={"data_root": data_folder}, + ) + datasets.append(dataset) + + for split in hparams["test_csv"]: + name = pl.Path(split).stem + dataset = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=split, replacements={"data_root": data_folder}, + ) + datasets.append(dataset) + + merged_data = { + key: value + for dataset in datasets + for key, value in dataset.data.items() + } + merged_dataset = DynamicItemDataset(merged_data) + + save_folder = pl.Path(hparams["save_folder"]) + logger.info("Extracting dataset tokens ...") + tokens_extractor.extract_tokens( + merged_dataset, + hparams["num_codebooks"], + (save_folder / "librispeech").as_posix(), + ) + + if hparams["save_embedding"]: + save_folder = pl.Path(hparams["save_folder"]) + logger.info(f"Saving embeddings ...") + tokens_extractor.save_pretrained_embeddings( + (save_folder / "embeddings").as_posix(), + vocab_size=hparams["vocab_size"], + num_codebooks=hparams["num_codebooks"], + ) diff --git a/benchmarks/DASB/LibriSpeech/extraction/hparams/dac.yaml b/benchmarks/DASB/LibriSpeech/extraction/hparams/dac.yaml new file mode 100644 index 000000000..d2d935ed0 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/extraction/hparams/dac.yaml @@ -0,0 +1,65 @@ +# ############################################################################ +# Auido Tokenizer: DAC +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed 
[!ref ] +output_folder: !ref results/dac +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +train_splits: ["train-clean-100"] #, "train-clean-360", "train-other-500" +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +batch_size: 8 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +####################### Model parameters ########################### +# Tokenizer parameters +# DAC parameters +# model_type: [16khz, 24khz, 44khz, 44khz] +# vocab_size: [1024, 1024, 1024, 1024] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +# max_num_codebooks: [12, 32, 9, 18] +# embedding_dim: [1024, 1024, 1024, 128] +model_type: 24khz +vocab_size: 1024 +model_bitrate: 8kbps +num_codebooks: 32 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +save_embedding: False + +tokenizer: !new:utils.tokenizer_interface.DACTokenizer + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/LibriSpeech/extraction/hparams/discrete_ssl.yaml b/benchmarks/DASB/LibriSpeech/extraction/hparams/discrete_ssl.yaml new file mode 100644 index 000000000..7d4938625 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/extraction/hparams/discrete_ssl.yaml @@ -0,0 +1,102 @@ +# ############################################################################ +# Auido Tokenizer: WavLM +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects 
with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/wavlm +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +train_splits: ["train-clean-100"] #, "train-clean-360", "train-other-500" +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +batch_size: 8 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +### Configuration for discrete SSL model +# | SSL Model | HF Encoder | K-Means Dataset | K-Means Size | SSL Layers | Vocoder Model | +# |------------|----------------------------------------|-----------------|--------------|----------------------|------------------------------------------| +# | WavLM | microsoft/wavlm-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wavlm-k1000-LibriTTS | +# | HuBERT | facebook/hubert-large-ll60k | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | WIP | +# | Wav2Vec2 | facebook/wav2vec2-large-960h-lv60-self | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | WIP | + +# ssl_model_type: hubert, wavlm, wav2vec2 +# ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large +ssl_model_type: WavLM +ssl_hub: microsoft/wavlm-large +ssl_folder: !ref /ssl_checkpoint +kmeans_cache_dir: !ref /kmeans_checkpoint +kmeans_dataset: LibriSpeech +vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS +freeze_ssl: True +freeze_feature_extractor: True +vocab_size: 1000 +save_embedding: False + +### Config for Tokenizer +# Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) +num_codebooks: [1, 3, 7, 12, 18, 23] +deduplicate: [False, False, 
False, False, False, False] +bpe_tokenizer_path: [null, null, null, null, null, null] +sample_rate: 16000 +encoder_dim: 1024 + +ssl_model: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + WavLM: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + HuBERT: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + Wav2Vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + +tokenizer: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !ref + vocoder_repo_id: !ref + kmeans_dataset: !ref + num_clusters: !ref + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/LibriSpeech/extraction/hparams/encodec.yaml b/benchmarks/DASB/LibriSpeech/extraction/hparams/encodec.yaml new file mode 100644 index 000000000..ee0a7e910 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/extraction/hparams/encodec.yaml @@ -0,0 +1,63 @@ +# ############################################################################ +# Auido Tokenizer: Encodec +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/encodec +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech 
+train_splits: ["train-clean-100"] #, "train-clean-360", "train-other-500" +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +batch_size: 8 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +# num_codebooks: [2, 4, 8, 16, 32] +bandwidth: 24.0 +num_codebooks: 32 +vocab_size: 1024 +sample_rate: 24000 +save_embedding: False + +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/LibriSpeech/extraction/hparams/speech_tokenizer.yaml b/benchmarks/DASB/LibriSpeech/extraction/hparams/speech_tokenizer.yaml new file mode 100644 index 000000000..5d897a782 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/extraction/hparams/speech_tokenizer.yaml @@ -0,0 +1,54 @@ +# ############################################################################ +# Auido Tokenizer: Speech Tokenizer +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/speech_tokenizer 
+save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +train_splits: ["train-clean-100"] #, "train-clean-360", "train-other-500" +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +batch_size: 8 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +vocab_size: 1024 +num_codebooks: 8 +sample_rate: 16000 +encoder_dim: 1024 +freeze_embedding: False +save_embedding: False + +# EnCodec model (see https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec) +tokenizer: !new:utils.tokenizer_interface.SpeechTokenizer + source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio + save_path: !ref + +tokens_extractor: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/LibriSpeech/extraction/librispeech_prepare.py b/benchmarks/DASB/LibriSpeech/extraction/librispeech_prepare.py new file mode 120000 index 000000000..a3126ec94 --- /dev/null +++ b/benchmarks/DASB/LibriSpeech/extraction/librispeech_prepare.py @@ -0,0 +1 @@ +../librispeech_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/README.md b/benchmarks/DASB/README.md index c3e42bf64..0ad632979 100644 --- a/benchmarks/DASB/README.md +++ b/benchmarks/DASB/README.md @@ -25,17 +25,31 @@ For detailed information, refer to [paper](https://arxiv.org/pdf/2406.14294): # Table of Contents -- [Table of Contents](#table-of-contents) -- [Installation](#-installation) -- [Discrete Audio Encoder](#-Discrete-Audio-Encoder) -- [Datasets and Recipes](#-Datasets-and-Recipes) -- [Quickstart](#-quickstart) - - [Running a single task](#Running-a-single-task) - - 
[Running multiple tasks](#Runnin-multiple-tasks) -- [‍Incorporating Your Audio Tokenizer](#-Incorporating-Your-Audio-Tokenizer) -- [Results](#-results) -- [Contact](#-contact) -- [Citing](#-citing) +Use the links below to jump to any section of this README: + +--- + +# 📑 Table of Contents + +- [DASB - Discrete Audio and Speech Benchmark](#dasb---discrete-audio-and-speech-benchmark) +- [🛠️ Installation](#-installation) +- [🎌 Discrete Audio Encoder](#-discrete-audio-encoder) +- [⚡ Datasets and Recipes](#-datasets-and-recipes) +- [📖 Training Scenarios](#-training-scenarios) + - [On-the-Fly Token Extraction](#on-the-fly-token-extraction) + - [Offline Token Extraction](#offline-token-extraction) +- [🎛️ Hyperparameter Tuning](#%EF%B8%8F-hyperparameter-tuning) +- [📝 Incorporating Your Audio Tokenizer](#-incorporating-your-audio-tokenizer) +- [📈 Results](#-results) + - [Ranking](#ranking) + - [Benchmarking Results for Discriminative Tasks](#benchmarking-results-for-discriminative-tasks) + - [Benchmarking Results for Generative Tasks](#benchmarking-results-for-generative-tasks) +- [📧 Contact](#-contact) +- [📖 Citing](#-citing) + +--- + +Each major section is linked above, with sub-sections listed where a section has more detailed content. # 🛠️ Installation @@ -98,51 +112,164 @@ To set up SpeechBrain-DASB, follow these steps: | Libri2Mix | Speech Separation | Conformer | CRDNN | [github.com/JorisCos/LibriMix](https://github.com/JorisCos/LibriMix) | | LJSpeech | Text-to-Speech | Shallow Transformer | Deep Transformer | [keithito.com/LJ-Speech-Dataset/](https://keithito.com/LJ-Speech-Dataset/) | -# ▶️ Quickstart +# 📖 Training Scenarios -## Running a single task +We offer two different training scenarios: **on-the-fly token extraction** and **offline token extraction**. 
-If you have specific discrete model and want to benchmark it for a specific task, you need to run the following command: - ``` - python LibriSpeech/ASR/LSTM/train_[tokenzier_name].py LibriSpeech/ASR/LSTM/hparams/train_[tokenzier_name].yaml --output_folder my-output-folder --data_folder mypath/to/LibriSpeech - ``` +## On-the-Fly Token Extraction +In this scenario, audio tokens are extracted dynamically during training. To enhance efficiency, we use a caching mechanism where tokens are saved in memory during the first epoch and retrieved for subsequent epochs. However, this approach has some limitations: +- It works best when the dataset is small, the bitrate is low, and batching is sorted (not random). +- It is unsuitable when data augmentation is required. -## Running multiple tasks +You can also disable the caching mechanism if needed. -To run all tasks, make the following changes: +Currently, the on-the-fly token extraction is applied only in the recipe located at: +`LibriSpeech/ASR-on-the-fly` -1. Edit the `run_discriminative_benchmark.sh` and `run_genarative_benchmark.sh` files and modify tokenizer related values for example the bitrate , number of codebooks, and etc. -2. Choose a set of tasks from the provided list and, for each task, select a downstream architecture from the available options (see list below). -3. Update the variables defined in `run_benchmark.sh` with two lists of equal size. In the `ConsideredTasks` list, specify the tasks you want to run (e.g., `'LibriSpeechASR' 'LibriSpeechASR' 'IEMOCAP'`). In the `Downstreams` list, specify the corresponding downstream architecture for each task (e.g., `'BiLSTM'`, `contextnet`, `'ecapa_tdnn'`). +If you wish to adapt this strategy for your own recipe, you can copy and modify the existing recipe as needed. 
Here's how to run the on-the-fly recipe: - For example, if you set `ConsideredTasks=('LibriSpeechASR' 'LibriSpeechASR' 'IEMOCAP')` and `Downstreams=('BiLSTM', 'contextnet', 'ecapa_tdnn')`, the benchmark will be executed as follows: - - LibriSpeechASR with BiLSTM as the probing head - - LibriSpeechASR with contextnet as the probing head - - IEMOCAP with ecapa_tdnn as the probing head. +```bash +python LibriSpeech/ASR-on-the-fly/train.py LibriSpeech/ASR-on-the-fly/hparams/LSTM/{TOKENIZER}.yaml --data_folder=path/LibriSpeech --output_folder=path/results/LibriSpeech/ASR/{TOKENIZER}/LSTM +``` + +> **Note:** On-the-fly extraction can be time-consuming, which is why we also provide an alternative approach: **offline token extraction**. + + +## Offline Token Extraction +In this scenario, all tokens are pre-extracted in a separate recipe. We recommend using the highest number of codebooks available for token extraction and then choosing the desired settings during training. + +### Token Extraction Command +To extract tokens, use the following command: + +```bash +python LibriSpeech/extraction/extract.py benchmarks/DASB/LibriSpeech/extraction/hparams/{tokenizer}.yaml --data_folder=path/LibriSpeech --num_codebooks=32 +``` + +If you wish to initialize your embedding layer with the tokenizer's embeddings while training your downstream model, set the flag `save_embedding` to `True`. For discrete SSL tokenizers, you can specify a list of layers for `--num_codebooks` instead of a single number (e.g., `--num_codebooks=[3,7,12]`). 
+ +### Training with Pre-Extracted Tokens +Once tokens are extracted and saved, you can train a downstream model using the following command: + +```bash +bash run_experiments.sh --hparams benchmarks/DASB/LibriSpeech/ASR/hparams/LSTM/train.yaml --data_folder LibriSpeech --cached_data_folder cache/ --output_folder results/LibriSpeech/ASR/encodec/LSTM --task ASR --dataset LibriSpeech --seed 1986 --nruns 2 --eval_metric WER --tokens_folder LibriSpeech/extraction-emb/speech_tokenizer/save/librispeech/ +``` + +--- + +This workflow ensures flexibility, efficiency, and reproducibility for both training scenarios. Adapt the recipes as needed for your specific requirements! + +The next section describes the standardized hyperparameter tuning protocol used in this benchmark. + + +# 🎛️ Hyperparameter Tuning + +Efficient hyperparameter tuning is critical when introducing novel models or experimenting with diverse datasets. Our benchmark establishes a standardized protocol for hyperparameter tuning, leveraging [Orion](https://orion.readthedocs.io/en/stable/) to ensure fair and consistent model comparisons. + +--- + +## **Overview** + +Hyperparameter tuning is managed using the `./run_hparam_optimization.sh` script. This script coordinates multiple hyperparameter trials via `run_experiments.sh`. -3. Run the following command: - ``` - bash run_discriminative_benchmark.sh [tokenzier_name] - bash run_genarative_benchmark.sh [tokenzier_name] - ``` - You could also pass extra arguments as far as they are consistent across all tasks. - For generative task, make sure to set the `utmos_path` required for TTS evaluation. + +## **Incorporating Orion Flags in Hparam Files** + +To enable tuning, Orion flags should be directly embedded in the YAML hparam file using comments.
For example, to optimize the model learning rate (`lr_model`) parameter within a defined range, include the following line in the YAML file: + +```yaml +lr_model: 0.0001 # @orion_step1: --lr_model~"loguniform(0.00001,0.5)" +``` + + + +## **Workflow of the Script** + +The script operates as follows: + +1. **Scans** the YAML hparam file for Orion flags. +2. **Executes** hyperparameter tuning using the `orion hunt` command. +3. **Saves** the best hyperparameters for reference; they are extracted from the output of `orion info`. +4. **Iterates** over the optimization steps for as long as flags such as `@orion_step<N>` are found in the YAML file. + + + +## **Running Hyperparameter Optimization** + +You can perform hyperparameter optimization using a command like this: + +```bash +bash run_hparam_optimization.sh \ + --exp_name 'ASR-encodec-LSTM_hopt' \ + --hparams LibriSpeech/ASR/hparams/LSTM/train.yaml \ + --data_folder path/LibriSpeech \ + --cached_data_folder path/cache/ \ + --output_folder results/LibriSpeech/ASR/encodec/LSTM \ + --task ASR \ + --dataset LibriSpeech \ + --seed 1986 \ + --nruns 1 \ + --nruns_eval 5 \ + --eval_metric WER \ + --exp_max_trials 50 \ + --tokens_folder results/LibriSpeech/extraction-emb/encodec/save/librispeech/ \ + --run_name encodec +``` + +For more details on the arguments and customization options, refer to `./run_hparam_optimization.sh`. + + +### **Notes** + +1. **Execution Time**: + - Hyperparameter tuning may take several hours or even days, depending on the model complexity and dataset. + +2. **GPU vs. CPU**: + - By default, models are trained on GPU. To train on CPU instead, include the `--device cpu` flag. + +3. **Monitoring Progress**: + - Use the following command to monitor optimization status: + ```bash + orion status --all + ``` + - Ensure that Orion-specific environment variables are set in your bash environment.
For example: + ```bash + export ORION_DB_ADDRESS=results/LibriSpeech/ASR/encodec/LSTM/ASR-encodec-LSTM_hopt.pkl + export ORION_DB_TYPE=pickleddb + ``` + Adjust `ORION_DB_ADDRESS` according to your experiment. + +4. **Resuming Optimization**: + - You can interrupt the script at any point. It will resume from the last completed trial. + +5. **Repetition of Optimization**: + - For multiple repetitions of the same hyperparameter optimization, modify the `--exp_name` parameter. + +6. **System Requirements**: + - The script is designed for Linux-based systems. A bash script is provided instead of Python due to its ability to manage diverse training loops across various subjects and sessions. + +--- + +This protocol ensures fair model comparison across diverse tasks and datasets. All reported results are derived using this standardized hyperparameter tuning methodology, enabling consistent assessments across models. + # 📝 ‍Incorporating Your Audio Tokenizer Let's now assume you've designed an audio and speech tokenizer in PyTorch and wish to integrate it into our benchmark. You're in luck because we've made this step as simple as possible for you! Here are the steps you should follow: -1. Write your model's code in a Python library saved in `benchmarks/DASB/model` (e.g., `benchmarks/MOABB/models/my_model.py`). -2. Create a YAML and py file for each task you want to experiment with. Thankfully, you don't have to start from scratch. For example, if you're working with LibriSpeech/ASR/LSTM, copy `benchmarks/DASB/LibriSpeech/ASR/contextnet/hparams/train_encodec.yaml` and save it in the same folder with a different name (e.g., `train_my_model.yaml` and `train_my_model.py`). +1. Write your model's code in a Python library saved in `benchmarks/DASB/model` (e.g., `benchmarks/DASB/model/my_model.py`). + +2.
Add the tokenizer to `utils/tokenizer_interface.py` and ensure the `encode` and `decode` functions are consistent in functionality and output shape with the other tokenizers. + +3. Create a YAML and Python file for each task you want to experiment with. Thankfully, you don't have to start from scratch. For example, you can copy `LibriSpeech/extraction/hparams/encodec.yaml`, adapt it based on your needs, and save it in the same folder with a different name (e.g., `LibriSpeech/extraction/hparams/{YOUR_TOKENIZER_NAME}.yaml`). -3. Edit the relevant section of your `train_my_model.yaml` and `train_my_model.py`. Redefine the `codec:` to reference your custom model (e.g., `codec: !new:models.my_model.my_model`). +4. Edit the relevant sections of your `{YOUR_TOKENIZER_NAME}.yaml`. Redefine the `tokenizer:` field to reference your custom model (e.g., `tokenizer: !new:utils.tokenizer_interface.your_tokenizer`). -4. Ensure you include the hyperparameters specific to your model. +5. Ensure you include the hyperparameters specific to your model. -5. Now, follow the instructions above to run an experiments across tasks. +6. Now, follow the instructions provided earlier to run experiments across tasks. **Note**: If you're not familiar with YAML, you can refer to our [HyperPyYAML tutorial](https://speechbrain.github.io/tutorial_basics.html) on the SpeechBrain website for guidance.
# 📈 Results diff --git a/benchmarks/DASB/extra_requirements.txt b/benchmarks/DASB/extra_requirements.txt index 4d1d241c3..e97e16b28 100644 --- a/benchmarks/DASB/extra_requirements.txt +++ b/benchmarks/DASB/extra_requirements.txt @@ -1,7 +1,10 @@ beartype jsonlines +kaldiio librosa>=0.9.2 onnxruntime>=1.16.3 +orion +orion[profet] scikit-learn speechbrain>=1.0.0 speechtokenizer>=0.1.2 diff --git a/benchmarks/DASB/model/ __init__.py b/benchmarks/DASB/model/ __init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/DASB/model/custom_model.py b/benchmarks/DASB/model/custom_model.py index b6e11a0d2..972d35c66 100644 --- a/benchmarks/DASB/model/custom_model.py +++ b/benchmarks/DASB/model/custom_model.py @@ -57,22 +57,31 @@ def __init__( num_codebooks, vocab_size, emb_dim, - pad_index=0, init=False, freeze=False, + hidden_dim=None, ): super(Discrete_EmbeddingLayer, self).__init__() self.vocab_size = vocab_size - self.num_codebooks = num_codebooks + self.num_codebooks = ( + len(num_codebooks) + if isinstance(num_codebooks, list) + else num_codebooks + ) self.freeze = freeze self.embedding = torch.nn.Embedding( - num_codebooks * vocab_size, emb_dim + self.num_codebooks * vocab_size, emb_dim ).requires_grad_(not self.freeze) self.init = init + # Add a linear layer to match dimensions if necessary + if hidden_dim is not None and hidden_dim != emb_dim: + self.proj_layer = torch.nn.Linear(emb_dim, hidden_dim) + else: + self.proj_layer = None + def init_embedding(self, weights): - with torch.no_grad(): - self.embedding.weight = torch.nn.Parameter(weights) + self.embedding.weight.data.copy_(weights) def forward(self, in_tokens): """Computes the embedding for discrete tokens. 
@@ -97,4 +106,6 @@ def forward(self, in_tokens): ) # Forward Pass to embedding and in_embs = self.embedding(in_tokens) + if self.proj_layer is not None: + in_embs = self.proj_layer(in_embs) return in_embs diff --git a/benchmarks/DASB/orion/hparams_tpe.yaml b/benchmarks/DASB/orion/hparams_tpe.yaml new file mode 100644 index 000000000..fb6a7c9b0 --- /dev/null +++ b/benchmarks/DASB/orion/hparams_tpe.yaml @@ -0,0 +1,6 @@ +experiment: + algorithms: + tpe: + seed: 1986 + n_initial_points: 20 + n_ei_candidates: 24 diff --git a/benchmarks/DASB/run_discriminative_benchmark.sh b/benchmarks/DASB/run_discriminative_benchmark.sh deleted file mode 100644 index 79383deb2..000000000 --- a/benchmarks/DASB/run_discriminative_benchmark.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -# Please consult the README.md file for instructions on how to run the benchmark. - -tokenizer_name=$1 -if [[ "$tokenizer_name" == "" ]]; then - echo "Usage: run_generative_benchmark.sh " - exit 1 -fi - -output_folder='/path/to/output' -declare -a DatasetsFolders=('path/to/LibriSpeech' 'path/to/CommonVoice' 'path/to/IEMOCAP' 'path/to/SLURP' 'path/to/Google-speech-commands' 'path/to/VoiceCeleb1') -declare -a ConsideredTasks=('LibriSpeech/ASR' 'CommonVoice/ASR' 'IEMOCAP/emotion_recognition' 'SLURP/intent_classification' 'Google-speech-commands/keyword-spotting' 'VoiceCeleb1/speaker_ver') -declare -a DownStreams=('LSTM' 'LSTM' 'ecapa_tdnn' 'LSTM_linear' 'Xvector','Xvector') -declare -a Locales=('cy' 'eu') -declare -a LocalesVobSize=(100 200) - -shift -script_args="$@" - -for i in "${!ConsideredTasks[@]}"; do - task=${ConsideredTasks[i]} - downstream=${DownStreams[i]} - dataset_folder=${DatasetsFolders[i]} - recipe_extra_args="$script_args" - set -- "$recipe_extra_args" - if [[ "$task" == "CommonVoice/ASR" ]]; then - echo "${tokenizer_name}/${task}/${downstream}" - for j in "${!Locales[@]}"; do - locale=${Locales[j]} - vocab=${LocalesVobSize[j]} - python $task/$downstream/train_$tokenizer_name.py 
$task/$downstream/hparams/train_$tokenizer_name.yaml --output_folder $output_folder/$tokenizer_name/$task/$downstream/$locale --data_folder $dataset_folder/$locale --language $locale --output_neurons $vocab $@ - done - else - python $task/$downstream/train_$tokenizer_name.py $task/$downstream/hparams/train_$tokenizer_name.yaml --output_folder $output_folder/$tokenizer_name/$task/$downstream --data_folder $dataset_folder $@ - fi -done diff --git a/benchmarks/DASB/run_experiments.sh b/benchmarks/DASB/run_experiments.sh new file mode 100755 index 000000000..e0f848aef --- /dev/null +++ b/benchmarks/DASB/run_experiments.sh @@ -0,0 +1,204 @@ +#!/bin/bash + +########################################################### +# Script to run downstream evaluation training, optionally with multiple seeds. +# This script loops over seeds and trains different models. +# At the end, the final performance is computed with the aggregate_results.py script that provides the average performance. +# +# Usage: +# ./run_experiments.sh --hparams benchmarks/DASB/LibriSpeech/ASR/hparams/LSTM/train.yaml --data_folder LibriSpeech --cached_data_folder cache/ \ +# --output_folder results/LibriSpeech/ASR/encodec/LSTM --task ASR --dataset LibriSpeech --seed 1986 --nruns 2 --eval_metric WER --tokens_folder LibriSpeech/extraction-emb/speech_tokenizer/save/librispeech/ + +# +# Authors: +# - Pooneh Mousavi (2024) +########################################################### + +# Initialize variables +hparams="" +data_folder="" +cached_data_folder="" +output_folder="" +task="" +dataset="" +seed="" +nruns="" +eval_metric="acc" +eval_set="test" +rnd_dir=False +additional_flags="" + + +# Function to print argument descriptions and exit +print_argument_descriptions() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --hparams hparams_path Hparam YAML file" + echo " --data_folder data_folder_path Data folder path" + echo " --cached_data_folder cache_path Cached data folder path" + echo " --output_folder 
output_path Output folder path" + echo " --task task downstream task" + echo " --dataset dataset dataset" + echo " --seed random_seed Seed (random if not specified)" + echo " --nruns num_runs Number of runs" + echo " --eval_metric metric Evaluation metric (e.g., acc or WER)" + echo " --eval_set dev or test Evaluation set. Default: test" + echo " --rnd_dir If True the results are stored in a subdir of the output folder with a random name (useful to store all the results of an hparam tuning). Default: False" + exit 1 +} + + +# Parse command line +POSITIONAL_ARGS=() + +while [[ $# -gt 0 ]]; do + case $1 in + --hparams) + hparams="$2" + shift + shift + ;; + + --data_folder) + data_folder="$2" + shift + shift + ;; + + --cached_data_folder) + cached_data_folder="$2" + shift + shift + ;; + + --output_folder) + output_folder="$2" + shift + shift + ;; + + --task) + task="$2" + shift + shift + ;; + + --dataset) + dataset="$2" + shift + shift + ;; + + --seed) + seed="$2" + shift + shift + ;; + + --nruns) + nruns="$2" + shift + shift + ;; + + --eval_metric) + eval_metric="$2" + shift + shift + ;; + + --eval_set) + eval_set="$2" + shift + shift + ;; + + --rnd_dir) + rnd_dir="$2" + shift + shift + ;; + + + --help) + print_argument_descriptions + ;; + + -*|--*) + additional_flags+="$1 $2 " # store additional flags + shift # past argument + ;; + + + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + + +# Check for required arguments +if [ -z "$hparams" ] ||[ -z "$data_folder" ] || [ -z "$output_folder" ] || [ -z "$nruns" ]; then + echo "ERROR: Missing required arguments! Please provide all required options." 
+ print_argument_descriptions +fi + +# Manage Seed (optional argument) +seed="${seed:-$RANDOM}" + + +if [ "$rnd_dir" = True ]; then + rnd_dirname=$(tr -dc 'a-zA-Z' < /dev/urandom | head -c 6) + output_folder="$output_folder/$rnd_dirname" +fi + +# Make sure the output_folder is created +mkdir -p $output_folder + +# Print command line arguments and save to file +{ + echo "hparams: $hparams" + echo "data_folder: $data_folder" + echo "cached_data_folder: $cached_data_folder" + echo "output_folder: $output_folder" + echo "task: $task" + echo "dataset: $dataset" + echo "seed: $seed" + echo "nruns: $nruns" + echo "eval_metric: $eval_metric" + echo "eval_set: $eval_set" + echo "rnd_dir: $rnd_dir" + echo "additional flags: $additional_flags" +} | tee "$output_folder/flags.txt" + + +# Creating output folder +mkdir -p $output_folder +mkdir -p $data_folder +mkdir -p $cached_data_folder + +# Function to run the training experiment +run_experiment() { + +python $dataset/$task/train.py $hparams --cached_data_folder=$cached_data_folder --seed=$seed --data_folder=$data_folder --output_folder=$output_folder_exp \ +$additional_flags + +} + +# Run multiple training experiments (with different seeds) +for i in $(seq 0 1 $(( nruns - 1 ))); do + ((run_idx = i + 1)) + run_name=run"$run_idx" + output_folder_exp="$output_folder"/"$run_name"/$seed + + run_experiment $output_folder_exp + + + # Changing Random seed + seed=$((seed+1)) +done + + +echo 'Final Results (Performance Aggregation)' +python utils/aggregate_results.py $output_folder "$eval_metric" | tee -a $output_folder/aggregated_performance.txt \ No newline at end of file diff --git a/benchmarks/DASB/run_extraction.sh b/benchmarks/DASB/run_extraction.sh new file mode 100644 index 000000000..92cc81381 --- /dev/null +++ b/benchmarks/DASB/run_extraction.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +########################################################### +# Script to extracts and save tokens from dataset. 
+# +# Usage: +# ./ $run_extraction.sh --data_folder LibriSpeech --output_folder results/LibriSpeech/ASR/encodec/LSTM --tokenizer encodec --dataset LibriSpeech + +# Authors: +# - Pooneh Mousavi (2024) +########################################################### + +# Initialize variables +data_folder="" +output_folder="" +tokenizer="" +dataset="" +save_embedding=False +additional_flags="" + + +# Function to print argument descriptions and exit +print_argument_descriptions() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --data_folder data_folder_path Data folder path" + echo " --output_folder output_path Output folder path" + echo " --tokenizer tokenizer tokenizer" + echo " --dataset dataset dataset" + echo " --save_embedding save_embedding If True the the embedding are saved. Default: False" + exit 1 +} + + +# Parse command line +POSITIONAL_ARGS=() + +while [[ $# -gt 0 ]]; do + case $1 in + --data_folder) + data_folder="$2" + shift + shift + ;; + + --output_folder) + output_folder="$2" + shift + shift + ;; + + --tokenizer) + tokenizer="$2" + shift + shift + ;; + + --dataset) + dataset="$2" + shift + shift + ;; + + --save_embedding) + save_embedding="$2" + shift + shift + ;; + + --help) + print_argument_descriptions + ;; + + -*|--*) + additional_flags+="$1 $2 " # store additional flags + shift # past argument + ;; + + + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + + +# Check for required arguments +if [ -z "$tokenizer" ] ||[ -z "$data_folder" ] || [ -z "$output_folder" ] || [ -z "$dataset" ]; then + echo "ERROR: Missing required arguments! Please provide all required options." 
+ print_argument_descriptions +fi + + +# Make sure the output_folder is created +mkdir -p $output_folder + +# Print command line arguments and save to file +{ + echo "data_folder: $data_folder" + echo "output_folder: $output_folder" + echo "tokenizer: $tokenizer" + echo "dataset: $dataset" + echo "save_embedding: $save_embedding" + echo "additional flags: $additional_flags" +} | tee "$output_folder/flags.txt" + + +# Creating output folder +mkdir -p $output_folder +mkdir -p $data_folder + +python $dataset/extraction/extract.py $dataset/extraction/hparams/$tokenizer.yaml --data_folder=$data_folder --output_folder=$output_folder --save_embedding=$save_embedding \ +$additional_flags diff --git a/benchmarks/DASB/run_generative_benchmark.sh b/benchmarks/DASB/run_generative_benchmark.sh deleted file mode 100644 index d5dc0d1d4..000000000 --- a/benchmarks/DASB/run_generative_benchmark.sh +++ /dev/null @@ -1,67 +0,0 @@ -#!/bin/bash -# Please consult the README.md file for instructions on how to run the benchmark. 
- -tokenizer_name=$1 -if [[ "$tokenizer_name" == "" ]]; then - echo "Usage: run_generative_benchmark.sh " - exit 1 -fi - -output_folder='path/to/output' -librimix_path='path/to/Libri2Mix' -voicebank_path='path/to/VoiceBank' -ljspeech_path='path/to/ljspeech' -utmos_path='path/to/utmos' -tts_args="--token_list_file_text %recipe_root%/hparams/char_en.txt --utmos_model_path $utmos_path" - -declare -a DatasetsFolders=(\ - "$librimix_path" \ - "$voicebank_path" \ - "$ljspeech_path" \ - "$ljspeech_path" \ -) -declare -a ConsideredTasks=(\ - 'Libri2Mix/separation' \ - 'VoiceBank/enhancement' \ - 'LJSpeech/TTS' \ - 'LJSpeech/TTS' \ -) -declare -a DownStreams=(\ - 'conformer' \ - 'conformer' \ - 'tokotron' \ - 'tokotron' \ -) -declare -a ExtraArgs=(\ - '' \ - '' \ - "$tts_args" \ - "$tts_args --enc_num_layers 3 --dec_num_layers 6" \ -) - -declare -a OutputSuffix=(\ - '' \ - '' \ - '' \ - '-small' -) - -shift -script_args="$@" - -for i in "${!ConsideredTasks[@]}"; do - task=${ConsideredTasks[i]} - downstream=${DownStreams[i]} - dataset_folder=${DatasetsFolders[i]} - extra_args=${ExtraArgs[i]} - suffix=${OutputSuffix[i]} - recipe_root="$task/$downstream" - recipe_extra_args="$script_args ${extra_args//%recipe_root%/$recipe_root}" - set -- "$recipe_extra_args" - echo "${tokenizer_name}/${task}/${downstream}" - python $task/$downstream/train_$tokenizer_name.py \ - $task/$downstream/hparams/train_$tokenizer_name.yaml \ - --output_folder $output_folder/$tokenizer_name/$task/$downstream$suffix \ - --data_folder $dataset_folder \ - $@ -done diff --git a/benchmarks/DASB/run_hparam_optimization.sh b/benchmarks/DASB/run_hparam_optimization.sh new file mode 100755 index 000000000..2ad1dddf3 --- /dev/null +++ b/benchmarks/DASB/run_hparam_optimization.sh @@ -0,0 +1,420 @@ +#!/bin/bash + +########################################################### +# Hyperparameter Tuning Script for EEG Model with Orion +########################################################### + +# Description: +# This 
script facilitates hyperparameter tuning for a given audio tokenizer, downstream model and dataset using Orion. + +# Usage: +# ./run_hparam_optimization.sh --exp_name 'ASR-encodec-LSTM_hopt' \ + # --hparams LibriSpeech/ASR/hparams/LSTM/train.yaml \ + # --data_folder path/LibriSpeech \ + # --cached_data_folder path/cache/ \ + # --output_folder results/LibriSpeech/ASR/encodec/LSTM \ + # --task ASR \ + # --dataset LibriSpeech \ + # --seed 1986 \ + # --nruns 1 \ + # --nruns_eval 5 \ + # --eval_metric WER \ + # --exp_max_trials 50 \ + # --tokens_folder results/LibriSpeech/extraction-emb/encodec/save/librispeech/ \ + # --run_name encodec +# Optimization Steps: +# The script supports multiple hyperparameter optimization steps. + +# Script Workflow: +# 1. Search for the orion flags in the specified hparam file. +# 2. Run the orion hunt command for hyperparameter tuning. +# By default, TPE (Tree-structured Parzen Estimator) hyperparameter tuning is +# performed, as specified in the default orion config file at orion/hparams_tpe.yaml. +# 3. Save the best hyperparameters, which are extracted from the output of orion info. +# 4. Loop for as long as flags like @orion_step<N> are found in the YAML file. +# +# Final Performance Evaluation: +# At the end of the optimization process, the script computes the final performance +# using the best hyperparameters on the test set. +# This is done by averaging over nruns_eval different seeds. +# +# Note: More detailed information can be found in the README.md file.
+ +# Authors: +# - Pooneh Mousavi 2024 +########################################################### + +# Initialize variables +exp_name="hopt" +hparams="" +data_folder="" +cached_data_folder="" +output_folder="" +task="" +dataset="" +seed=1986 +nruns="" +nruns_eval=10 +eval_metric="acc" +config_file="orion/hparams_tpe.yaml" +mne_dir="" +orion_db_address="" +orion_db_type="PickledDB" +exp_max_trials=50 +store_all=True +compress_exp=True + +# Function to print argument descriptions and exit +print_argument_descriptions() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --exp_name Name Name that Orion gives to the experiment" + echo " --hparms hparam_file YAML file containing the hparam to optimize. The hyperparameters decorated with @orion_step1 or @orion_step1 in the YAML file will be used" + echo " --data_folder data_path Folder were the data are stored. If not available, they will be downloaded there." + echo " --cached_data_folder path [Optional] Folder were the data in pkl format will be cached." + echo " --output_folder output_path Output folder were the results will be stored" + echo " --task task downstream task" + echo " --dataset dataset dataset" + echo " --seed random_seed [Optional] Seed (random if not specified)" + echo " --nruns num_runs Number of runs for each hparam selection." + echo " --nruns_eval num_runs Number of runs for the final evaluation (with best hparams) on the test set" + echo " --eval_metric metric [Optional] Evaluation metric description. Default:acc" + echo " --config_file config_file [Optional] Orion config file. Default: hparams/orion/hparams_tpe.yaml" + echo " --mne_dir mne_dir [Optional] MNE directory. Need it different from your home (see notes on MNE in README.md)" + echo " --orion_db_address [Optional] Path of the database where orion will store hparams and performance" + echo " --orion_db_type db_type [Optional] Type of the dataset that orion will use. 
Default: PickledDB" + echo " --exp_max_trials int [Optional] Maximum number of hparam trials for each oprimization step. Default:50" + echo " --store_all Bool [Optional] When set to True, the output folders of all hparam trials will be stored in randomly named folders. Default: False" + echo " --compress_exp Bool [Optional] When set to True, this option compresses the output folders of all hyperparameter trials into a single tar.gz file. This is particularly useful when store_all is set to True, as it helps prevent the accumulation of a large number of files. Default: False" + exit 1 +} + +POSITIONAL_ARGS=() + +while [[ $# -gt 0 ]]; do + case $1 in + + --exp_name) + exp_name="$2" + shift + shift + ;; + + --hparams) + hparams="$2" + shift + shift + ;; + + --data_folder) + data_folder="$2" + shift + shift + ;; + + --cached_data_folder) + cached_data_folder="$2" + shift + shift + ;; + + --output_folder) + output_folder="$2" + shift + shift + ;; + + --task) + task="$2" + shift + shift + ;; + + --dataset) + dataset="$2" + shift + shift + ;; + + --seed) + seed="$2" + shift + shift + ;; + + --nruns) + nruns="$2" + shift + shift + ;; + + --nruns_eval) + nruns_eval="$2" + shift + shift + ;; + + --eval_metric) + eval_metric="$2" + shift + shift + ;; + + --config_file) + config_file="$2" + shift + shift + ;; + + --mne_dir) + mne_dir="$2" + shift + shift + ;; + + --orion_db_address) + orion_db_address="$2" + shift + shift + ;; + + --orion_db_type) + orion_db_type="$2" + shift + shift + ;; + + --exp_max_trials) + exp_max_trials="$2" + shift + shift + ;; + + --store_all) + store_all="$2" + shift + shift + ;; + + --compress_exp) + compress_exp="$2" + shift + shift + ;; + + --help) + print_argument_descriptions + ;; + + -*|--*) + additional_flags+="$1 $2 " # store additional flags + shift # past argument + ;; + + + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + + +# Check for required arguments +if [ -z "$output_folder" ] || [ -z 
"$data_folder" ] || [ -z "$hparams" ] || [ -z "$nruns" ]; then + echo "ERROR: Missing required arguments! Please provide all required options." + print_argument_descriptions +fi + +# Set mne_dir if specified +if [ "$mne_dir" ]; then + export _MNE_FAKE_HOME_DIR=$mne_dir +fi + +# Assign default value to cached_data_folder +if [ -z "$cached_data_folder" ]; then + cached_data_folder="$data_folder/cache" +fi + + +# Set orion db address if specified +if [ -z "$orion_db_address" ]; then + orion_db_address=$output_folder'/'$exp_name'.pkl' +fi +export ORION_DB_ADDRESS=$orion_db_address +export ORION_DB_TYPE=$orion_db_type + +echo "-------------------------------------" +echo "Experiment Name: $exp_name" +echo "hparams: $hparams" +echo "Output Folder: $output_folder" +echo "Data Folder: $data_folder" +echo "Cached Data Folder: $cached_data_folder" +echo "task: $task" +echo "dataset: $dataset" +echo "Hparam File: $hparams" +echo "Number of Runs: $nruns" +echo "Number of Eval Runs: $nruns_eval" +echo "Eval Metric: $eval_metric" +echo "Seed: $seed" +echo "Additional Flags: $additional_flags" +echo "Orion Config File: $config_file" +echo "Orion Database type: $orion_db_type" +echo "Orion Database file: $orion_db_address" +echo "Experiment Max Trials: $exp_max_trials" +echo "-------------------------------------" + + +# This function will extract all the optimization flags added in the yaml file +# The input is a text file (e.g, a yaml file) and a pattern (e.g, "@orion_step1:") +# The ouput are the detected flags (e.g., --dropout~"uniform(0.0, 0.5)"). +get_flag() { + local file_path="$1" + local pattern="$2" + + # Check if the file exists + if [ ! -f "$file_path" ]; then + echo "Error: File '$file_path' not found." 
+ return 1 + fi + + # Use grep to find all lines containing the pattern and then extract the flags using sed + grep -o "$pattern.*" "$file_path" | sed "s/$pattern//" | tr -d '\n' +} + + +# Function for updatading the hparam yaml file with the best hparams found at step 1 +update_hparams() { + local best_hparams_file="$1" + local hparams_yaml_file="$2" + local output_yaml_file="$3" + + # Read the values from best_hparams.txt into an associative array + declare -A best_hparams + while IFS=": " read -r key value; do + best_hparams["$key"]=$value + done < "$best_hparams_file" + + + # Read the hparams.yaml file into a variable + local hparams_content=$(cat "$hparams_yaml_file") + + # Update values in hparams_content using values from best_hparams + for key in "${!best_hparams[@]}"; do + local pattern="^$key: .*" + local replacement="$key: ${best_hparams[$key]}" + hparams_content=$(sed "s/$pattern/$replacement/g" <<< "$hparams_content") + done + + # Write the updated content to a new YAML file + echo "$hparams_content" > "$output_yaml_file" +} + +# Function for extracting the best hparams from orion-info +function extract_best_params() { + local input_file="$1" + local best_trial_line=$(grep -n "best trial:" "$input_file" | cut -d ":" -f 1) + local params_lines=$(tail -n +$best_trial_line "$input_file" | awk '/params:/{flag=1;next}/start time:/{flag=0}flag') + local formatted_params=$(echo "$params_lines" | sed -e 's/^[[:space:]]*//' -e 's/: /: /' -e '/^$/d' -e 's#^/##') + echo "$formatted_params" +} + +# Running hparam tuning (loop over multiple steps) +step_id=1 +hparams_step=$hparams +pattern="@orion_step1:" +opt_flags=$(get_flag "$hparams_step" "$pattern") + +# Check if the string is empty and exit with an error if it is +if [ -z "$opt_flags" ]; then + echo "Error: Optimization flags not found in '$hparams'" + echo "Please ensure that the Orion optimization flags are set in the hparam file using in-line comments like:" + echo "# @orion_step1: --dropout~\"uniform(0.0, 
# Iterative hparam tuning. Each pass optimizes the flags tagged with
# "@orion_step<N>:" in the current hparam file, stores the results in
# "$output_folder/step<N>" and hands the updated YAML file to the next pass.
# The loop ends at the first step for which no optimization flags are found.
while [ -n "$opt_flags" ]; do
    output_folder_step="$output_folder/step$step_id"
    mkdir -p "$output_folder_step"
    exp_name_step="${exp_name}_step$step_id"

    echo
    echo "**********************************************************************************************"
    echo "Running hparam tuning (step $step_id)..."
    echo "- This might take several hours!"
    echo "- The best set of hparams will be save in $output_folder_step"
    # The experiment is registered under the per-step name, so that is the
    # name the user has to query.
    echo "- You can monitor the evolution of the hparam optimization with: orion status -n $exp_name_step"
    echo "......"
    echo "**********************************************************************************************"
    echo

    # Setting up orion command
    orion_hunt_command="orion hunt -n $exp_name_step -c $config_file --exp-max-trials $exp_max_trials \
        ./run_experiments.sh --hparams $hparams_step --data_folder $data_folder --cached_data_folder $cached_data_folder \
        --output_folder $output_folder_step/exp --task $task --dataset $dataset --seed $seed --nruns $nruns \
        --eval_metric $eval_metric --eval_set dev --rnd_dir $store_all --testing False $additional_flags"

    # Appending the optimization flags
    orion_hunt_command="$orion_hunt_command $opt_flags"

    # Keep a copy of the exact command for reproducibility
    echo "$orion_hunt_command" &> "$output_folder_step/orion_hunt_command.txt"

    # Execute the command for hparam tuning. eval is required because the
    # optimization flags carry their own quoting.
    eval "$orion_hunt_command"

    # Compress the exp folder (if required)
    if [ "$compress_exp" = True ]; then
        # Delete the raw folder only after the archive was written successfully
        if tar -czf "$output_folder_step/exp.tar.gz" "$output_folder_step/exp"; then
            rm -rf "$output_folder_step/exp"
        else
            echo "Warning: could not compress $output_folder_step/exp; the folder is kept." >&2
        fi
    fi

    # Storing best hparams
    orion info --name "$exp_name_step" &> "$output_folder_step/orion-info.txt"

    # Extract and store the best set of hparams
    best_params_output=$(extract_best_params "$output_folder_step/orion-info.txt")
    best_hparams_file="$output_folder_step/best_hparams.txt"
    echo "$best_params_output" > "$best_hparams_file"

    # Store the current best yaml file
    best_yaml_file="$output_folder_step/best_hparams.yaml"
    update_hparams "$best_hparams_file" "$hparams_step" "$best_yaml_file"

    # Update best hparam step
    hparams_step=$best_yaml_file

    # Update step variable
    ((step_id++))

    # Update search pattern
    pattern="@orion_step$step_id:"

    # Update optimization flags (empty when the next step is not defined)
    opt_flags=$(get_flag "$hparams_step" "$pattern")
done

echo
echo "**********************************************************************************************"
echo "Running Final Evaluation on the best hparams (test-set)..."
echo "**********************************************************************************************"
echo

# Plain local copy of the YAML file produced by the last tuning step
final_yaml_file="$output_folder/best_hparams.yaml"
cp "$best_yaml_file" "$final_yaml_file"

# Running evaluation on the test set for the best models
./run_experiments.sh --hparams "$final_yaml_file" --data_folder "$data_folder" --cached_data_folder "$cached_data_folder" \
    --output_folder "$output_folder/best" --task $task --dataset $dataset --seed $seed \
    --nruns $nruns_eval --eval_metric $eval_metric --eval_set test \
    --rnd_dir $store_all --testing True $additional_flags

echo "The test performance with best hparams is available at $output_folder/best"
# Token written by get_prototype where the aggregated values must be printed.
# NOTE(review): the token text is missing from this copy of the file (the
# docstrings read "with  as placeholder"); "<values>" is assumed - confirm.
_PLACEHOLDER = "<values>"


def get_prototype(res_file, eval_metric):
    """Parses a result file and adds a placeholder where the aggregated metrics
    should be printed. It also returns the number of detected metrics.

    Arguments
    ---------
    res_file: path
        Path of the result file to parse.
    eval_metric: str
        Metric of interest (e.g, acc or f1).

    Returns
    -------
    prototype: list
        List of the lines of the result file. Every line that contains
        `eval_metric` is cut right after the metric name and ends with the
        placeholder token instead of the original value.
    n_metrics: int
        Number of metrics to replace in the result files.
    """
    prototype = []
    n_metrics = 0

    # Open the first res file and figure out where the metrics are
    with open(res_file) as file_in:
        for line in file_in:
            if eval_metric in line:
                # Keep the text before the metric name, drop the value
                prefix = line.split(eval_metric)[0]
                line = prefix + eval_metric + " " + _PLACEHOLDER
                n_metrics += 1
            prototype.append(line)
    return prototype, n_metrics


def get_metrics(res_files, eval_metric, n_metrics=None):
    """Summarizes the metrics of interest in a matrix.

    Arguments
    ---------
    res_files: list
        List of all the result files.
    eval_metric: str
        Metric of interest (e.g, acc or f1).
    n_metrics: int, optional
        Number of metric rows. When None, it is the number of lines of the
        first result file that contain `eval_metric` (the value returned by
        `get_prototype` for that file).

    Returns
    -------
    metrics: np.array
        Matrix (n_metrics, n_files) containing the metrics of interest.
    """
    if n_metrics is None:
        n_metrics = (
            get_prototype(res_files[0], eval_metric)[1] if res_files else 0
        )

    # Metric initialization
    metrics = np.zeros([n_metrics, len(res_files)])

    # "<metric>: <number>"; the metric name is matched literally and the
    # number may be an integer, a decimal or use exponent notation.
    value_pattern = re.compile(
        rf"{re.escape(eval_metric)}: (\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"
    )

    # Loop over files
    for file_idx, res_file in enumerate(res_files):
        row = 0
        with open(res_file) as file_in:
            for line in file_in:
                if eval_metric not in line:
                    continue
                match = value_pattern.search(line)
                if match:
                    metrics[row, file_idx] = float(match.group(1))
                    row += 1
    return metrics


def aggregate_metrics(prototype, metrics):
    """Prints the aggregated metrics. It replaces the placeholders with
    the corresponding metrics.

    Arguments
    ---------
    prototype: list
        List of the lines of the result file, as returned by `get_prototype`.
    metrics: np.array
        Matrix (n_metrics, n_files) containing the metrics of interest.
    """
    row = 0
    for line in prototype:
        # Metric lines are recognized by the placeholder itself, so no
        # module-level state is needed.
        if _PLACEHOLDER in line:
            values = " ".join("%f" % float(value) for value in metrics[row, :])
            values_line = "[" + values + "] avg: %f ± %f " % (
                float(metrics[row, :].mean()),
                float(metrics[row, :].std()),
            )
            line = line.replace(_PLACEHOLDER, values_line)
            row += 1
        print(line)
+ if eval_metric == "acc" or eval_metric == "f1": + final_metric = 1 - final_metric + report_objective(final_metric) diff --git a/benchmarks/DASB/utils/tokenizer_interface.py b/benchmarks/DASB/utils/tokenizer_interface.py new file mode 100644 index 000000000..ff1194968 --- /dev/null +++ b/benchmarks/DASB/utils/tokenizer_interface.py @@ -0,0 +1,255 @@ +""" +Unified interface for tokenizers, standardizing the output shape of encode and decode functions. + +This class reshapes the outputs of various tokenizers to ensure consistency, simplifying integration with recipes and workflows. + +Authors +--------- +* Pooneh Mousavi, 2024 +""" + +import torch +from abc import ABC, abstractmethod +from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec +from speechbrain.lobes.models.huggingface_transformers.discrete_ssl import ( + DiscreteSSL, +) +from speechbrain.lobes.models.discrete.dac import DAC +from speechbrain.lobes.models.discrete.speechtokenizer_interface import ( + SpeechTokenizer_interface, +) + + +class BaseTokenizer(ABC): + """ + Abstract base class for tokenizers that encode signals into discrete tokens + and decode tokens back into signals. + + This class defines the essential methods that any tokenizer must implement, + including encoding, decoding, and retrieving pretrained embeddings. + + Naming Convenstion + ------------------ + B : int + Batch size. + T : int + Sequence length in the time domain. + N : int + Sequence length in the token domain. + C : int + Vocabulary size, assuming each codebook has the same number of tokens. + K : int + Number of codebooks. + """ + + def __init__(self): + """ + Initialize the BaseTokenizer. + + This is a base constructor that other tokenizers can extend. + """ + super().__init__() + + @abstractmethod + @torch.no_grad() + def sig_to_tokens(self, signal, lengths, num_codebooks=None, **kwargs): + """ + Encode a signal into discrete tokens. 
+ + Arguments + --------- + signal : torch.Tensor + Input signal with shape [B, T]. + lengths : torch.Tensor + Lengths of each sequence in the batch, with shape [B]. + num_codebooks : int, optional + Number of codebooks to use for encoding. If None, all codebooks are used (default: None). + If specified as an int, the tokens will be truncated to include only the first `num_codebooks` codebooks. If specified as a list, + the tokens will include only the codebooks at the specified indices. + **kwargs : dict + Additional arguments for the tokenizer. + + Returns + ------- + tokens : torch.Tensor + Discretized tokens with shape [B, N, K]. + """ + pass + + @abstractmethod + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + """ + Decode discrete tokens back into a signal. + + Arguments + --------- + tokens : torch.Tensor + Input tokens with shape [B, N, K]. + **kwargs : dict + Additional arguments for the tokenizer. + + Returns + ------- + signal : torch.Tensor + Reconstructed signal with shape [B, T]. + """ + pass + + @abstractmethod + @torch.no_grad() + def get_pretrained_embeddings(self, vocab_size, num_codebooks, **kwargs): + """ + Retrieve pretrained embeddings for the tokenizer. + + Arguments + --------- + vocab_size : int + Number of tokens in each codebook. + num_codebooks : int + Number of codebooks. + **kwargs : dict + Additional arguments for embedding retrieval. + + Returns + ------- + embeddings : torch.Tensor + Pretrained embedding weights with shape [K, C, H], where H is the embedding dimension. 
class EncodecTokenizer(Encodec, BaseTokenizer):
    """Encodec model exposed through the BaseTokenizer interface."""

    def __init__(self, *args, **kwargs):
        Encodec.__init__(self, *args, **kwargs)
        BaseTokenizer.__init__(self)

    @torch.no_grad()
    def sig_to_tokens(self, signal, lengths, num_codebooks=None, **kwargs):
        """Encode `signal` and optionally keep only the first codebooks.

        The tokens are the first output of `self.encode(signal, lengths)`.
        When `num_codebooks` is set, only the first `num_codebooks` entries
        of the last dimension are returned; a ValueError is raised if fewer
        are available.
        """
        self.eval()
        tokens, _ = self.encode(signal, lengths)
        if not num_codebooks:
            return tokens
        if tokens.shape[-1] < num_codebooks:
            raise ValueError(
                f"Model only outputs {tokens.shape[-1]} codebooks, but {num_codebooks} requested"
            )
        return tokens[..., :num_codebooks]

    @torch.no_grad()
    def tokens_to_sig(self, tokens, **kwargs):
        """Decode `tokens` and return index 0 of the decoder output's second dimension."""
        self.eval()
        decoded = self.decode(tokens)
        return decoded[:, 0]

    @torch.no_grad()
    def get_pretrained_embeddings(
        self, vocab_size=None, num_codebooks=None, **kwargs
    ):
        """Return `self.vocabulary` flattened to two dimensions.

        `vocab_size` and `num_codebooks` are accepted for interface
        compatibility and are not used.
        """
        vocabulary = self.vocabulary
        hidden_dim = vocabulary.shape[-1]
        return vocabulary.reshape(-1, hidden_dim)


class DACTokenizer(DAC, BaseTokenizer):
    """DAC model exposed through the BaseTokenizer interface."""

    def __init__(self, *args, **kwargs):
        DAC.__init__(self, *args, **kwargs)
        BaseTokenizer.__init__(self)

    @torch.no_grad()
    def sig_to_tokens(self, signal, lengths, num_codebooks=None, **kwargs):
        """Encode `signal`; `num_codebooks` is forwarded as `n_quantizers`.

        A singleton dimension is inserted after the batch dimension before
        the model is called, and the last two dimensions of the returned
        tokens are swapped. `lengths` is not used.
        """
        self.eval()
        model_input = signal[:, None]
        codes, _ = self(model_input, n_quantizers=num_codebooks)
        return codes.movedim(-1, -2)

    @torch.no_grad()
    def tokens_to_sig(self, tokens, **kwargs):
        """Rebuild the quantized features from `tokens` and decode them."""
        self.eval()
        codes = tokens.movedim(-1, -2)
        quantized_feats, _, _ = self.quantizer.from_codes(codes)
        decoded = self.decode(quantized_feats)
        return decoded[:, 0]

    @torch.no_grad()
    def get_pretrained_embeddings(
        self, vocab_size=None, num_codebooks=None, **kwargs
    ):
        """Collect the projected embedding of every token of every codebook.

        Token ids 0..vocab_size-1 are pushed through the quantizer for
        `num_codebooks` codebooks; the per-codebook outputs are projected with
        the matching `out_proj` and concatenated along the first dimension.
        """
        device = next(self.parameters()).device
        toks = torch.arange(vocab_size).to(device)
        toks = toks[:, None, None].expand(-1, num_codebooks, -1).clone()
        self.eval()
        z_q, z_p, _ = self.quantizer.from_codes(toks)
        # One equally sized chunk of the second dimension per codebook
        chunk_size = z_p.shape[1] // toks.shape[1]
        projected = []
        for index, chunk in enumerate(z_p.split(chunk_size, dim=1)):
            projected.append(self.quantizer.quantizers[index].out_proj(chunk))
        return torch.cat(projected)[:, :, 0]
class DiscreteSSLTokenizer(DiscreteSSL, BaseTokenizer):
    """DiscreteSSL model exposed through the BaseTokenizer interface.

    For this tokenizer `num_codebooks` is a collection of SSL layer ids
    rather than a count: it is forwarded as `SSL_layers` when encoding and
    matched against `self.ssl_layer_ids` when collecting the embeddings.
    """

    def __init__(self, *args, **kwargs):
        DiscreteSSL.__init__(self, *args, **kwargs)
        BaseTokenizer.__init__(self)

    @torch.no_grad()
    def sig_to_tokens(self, signal, lengths, num_codebooks=None, **kwargs):
        """Encode `signal` and return the first output of `self.encode`.

        Arguments
        ---------
        signal : torch.Tensor
            Input signal.
        lengths : torch.Tensor
            Lengths of the sequences in the batch.
        num_codebooks : list, optional
            Forwarded to `self.encode` as `SSL_layers`.
        **kwargs : dict
            Extra arguments forwarded to `self.encode`.

        Returns
        -------
        tokens : torch.Tensor
            The tokens returned by `self.encode`.
        """
        self.eval()
        tokens, _, _ = self.encode(
            signal, lengths, SSL_layers=num_codebooks, **kwargs
        )
        return tokens

    @torch.no_grad()
    def tokens_to_sig(self, tokens, **kwargs):
        """Decode `tokens` with `self.decode`, forwarding `kwargs`."""
        self.eval()
        return self.decode(tokens, **kwargs)

    @torch.no_grad()
    def get_pretrained_embeddings(
        self, vocab_size=None, num_codebooks=None, **kwargs
    ):
        """Concatenate the vocabularies of the selected SSL layers.

        Arguments
        ---------
        vocab_size : int, optional
            Not used; accepted for interface compatibility.
        num_codebooks : list, optional
            Layer ids to keep. When None (the default), the vocabularies of
            all layers in `self.ssl_layer_ids` are used.
        **kwargs : dict
            Not used.

        Returns
        -------
        embs : torch.Tensor
            The float32 vocabularies of the selected layers, concatenated
            along the first dimension in the order of `self.ssl_layer_ids`.
        """
        embs = [
            torch.as_tensor(vocabulary, dtype=torch.float32)
            for layer_num, vocabulary in zip(
                self.ssl_layer_ids, self.vocabularies
            )
            # None selects every layer; the membership test is only valid
            # for an actual collection of layer ids.
            if num_codebooks is None or layer_num in num_codebooks
        ]
        return torch.cat(embs)
logger = logging.getLogger(__name__)

# Marker file written after a completed extraction; it stores the
# configuration that `_skip` compares against on the next run.
OPT_FILE = "opt_extract.pkl"


def get_device(use_cuda):
    """Return the torch device to run on.

    Arguments
    ---------
    use_cuda : bool
        Whether CUDA is requested.

    Returns
    -------
    torch.device
        "cuda" when CUDA is requested and available, "cpu" otherwise.
    """
    cuda_available = torch.cuda.is_available()
    logger.info("=" * 30)
    logger.info(f"USE_CUDA SET TO: {use_cuda}")
    logger.info(f"CUDA AVAILABLE?: {cuda_available}")
    logger.info("=" * 30)
    return torch.device("cuda" if use_cuda and cuda_available else "cpu")


class TokensExtractor:
    """
    Extracts tokens from audio data using a tokenizer and saves them to a specified format.

    Arguments
    ---------
    tokenizer : torch.nn.Module
        The tokenizer model to use for token extraction.
    sample_rate : int
        The sample rate of the audio data.
    src_key : str, optional
        The key in the dataset that contains the audio data (default: "wav").
    id_key : str, optional
        The key in the dataset that contains unique identifiers (default: "id").
    save_format : str, optional
        One of 'numpy', 'pickle', 'soundfile_flac' (default: "numpy"). The
        value is validated and stored; `extract_tokens` always writes with
        the "numpy" write function.
    use_cuda : bool, optional
        Whether to use CUDA for computation (default: True).
    dataloader_opts : dict, optional
        Options for the data loader (default: None).

    Raises
    ------
    ValueError
        If an unsupported save_format is provided.
    ValueError
        If the tokenizer's sample rate does not match the provided sample_rate.
    """

    def __init__(
        self,
        tokenizer,
        sample_rate,
        src_key="wav",
        id_key="id",
        save_format="numpy",
        use_cuda=True,
        dataloader_opts=None,
    ):
        # Defined first so that __del__ is safe even when the checks below
        # raise or when extract_tokens is never called.
        self.writer = None

        self.id_key = id_key
        self.src_key = src_key

        self.device = get_device(use_cuda)
        self.tokenizer = tokenizer.to(self.device)
        self.sample_rate = sample_rate

        if tokenizer.sample_rate != self.sample_rate:
            raise ValueError(
                f"Sample rate mismatch: {self.sample_rate} != {tokenizer.sample_rate}"
            )

        if save_format not in ["numpy", "pickle", "soundfile_flac"]:
            raise ValueError(f"Unsupported save_format: {save_format}")
        self.save_format = save_format

        if not dataloader_opts:
            dataloader_opts = {}
        self.dataloader_opts = dataloader_opts
        self.pipelines = self._make_pipelines()

    def extract_tokens(
        self, dataset, num_codebooks, save_path, save_name="tokens"
    ):
        """
        Extracts tokens from the dataset and saves them to the specified format.

        The extraction is skipped when `_skip` reports that a previous run
        with the same configuration already completed. The output files are
        closed before the completion marker is written.

        Arguments
        ---------
        dataset : speechbrain.dataio.dataset.DynamicItemDataset or dict
            The dataset from which to extract tokens. Can be a DynamicItemDataset or a dictionary.
        num_codebooks: int
            The number of codebooks to retrieve from the tokens.
        save_path: str
            The path where tokens will be saved.
        save_name: str
            The name of the .scp and .ark files.
        """
        conf = {
            "sample_rate": self.sample_rate,
            "save_folder": save_path,
            "dataset_length": len(dataset),
        }

        save_path = pl.Path(save_path).absolute()
        save_path.mkdir(parents=True, exist_ok=True)

        # Check if the extraction is already done (if so, skip it)
        if _skip(save_path, save_name, conf):
            logger.info("Skipping extraction, completed in previous run.")
            return

        # Release a writer left over from a previous call before opening
        # a new one.
        self._close_writer()
        self.wspecifier = (
            f"ark,scp,t:{save_path}/{save_name}.ark,{save_path}/{save_name}.scp"
        )
        self.writer = kaldiio.WriteHelper(
            self.wspecifier, write_function="numpy"
        )

        try:
            if isinstance(dataset, dict):
                dataset = DynamicItemDataset(dataset)
            dataset.set_output_keys([self.src_key, self.id_key, "sig"])
            for pipeline in self.pipelines:
                dataset.add_dynamic_item(pipeline)

            dataloader = make_dataloader(dataset, **self.dataloader_opts)
            batch_size = self.dataloader_opts.get("batch_size", 1)
            batch_count = int(math.ceil(len(dataset) / batch_size))
            for batch in tqdm(dataloader, total=batch_count):
                batch = batch.to(self.device)
                x, x_lengths = batch["sig"]
                ids = batch[self.id_key]
                batch_tokens = self.tokenizer.sig_to_tokens(
                    x, x_lengths, num_codebooks=num_codebooks
                )
                batch_tokens = sb.utils.data_utils.undo_padding(
                    batch_tokens, x_lengths
                )
                self.process_batch(batch_tokens, ids)
        finally:
            # Close the output files whether or not the loop succeeded
            self._close_writer()

        logger.info("Extraction completed.")

        # Written last, and only on success: its presence marks a finished run
        save_opt = save_path / OPT_FILE
        save_pkl(conf, save_opt.as_posix())

    def process_batch(self, batch, ids):
        """
        Processes a batch of tokens and writes them to the output files.

        Arguments
        ---------
        batch : list
            A list of tokens for each item in the batch.
        ids : list
            A list of unique identifiers corresponding to each item in the batch.
        """
        for tokens, utt_id in zip(batch, ids):
            tokens = np.array(tokens)
            self.writer(utt_id, tokens)

    def _make_pipelines(self):
        """
        Creates the data processing pipeline for audio data.

        The pipeline reads audio files, resamples them to the desired sample rate, and provides
        the processed signal under the key "sig".

        Returns
        -------
        pipeline : list
            A list containing the audio processing pipeline function.
        """

        @sb.utils.data_pipeline.takes(self.src_key)
        @sb.utils.data_pipeline.provides("sig")
        def audio_pipeline(wav):
            info = torchaudio.info(wav)
            sig = sb.dataio.dataio.read_audio(wav)
            sig = torchaudio.transforms.Resample(
                info.sample_rate, self.sample_rate,
            )(sig)
            return sig

        return [audio_pipeline]

    def save_pretrained_embeddings(
        self,
        save_path,
        save_name="embeddings",
        vocab_size=None,
        num_codebooks=None,
    ):
        """
        Saves the pretrained embeddings of the tokenizer to a specified directory.

        This method retrieves the pretrained embeddings from the tokenizer,
        converts them to a NumPy array, and saves them with `np.save`.

        Arguments
        ---------
        save_path : str or pathlib.Path
            The directory where the pretrained embeddings will be saved.
            If the directory does not exist, it will be created.
        save_name : str, optional
            The base name of the saved embeddings file (default: "embeddings").
        vocab_size : int, optional
            Forwarded to the tokenizer's `get_pretrained_embeddings`.
        num_codebooks : int or list, optional
            Forwarded to the tokenizer's `get_pretrained_embeddings`.
        """
        save_path = pl.Path(save_path).absolute()
        save_path.mkdir(parents=True, exist_ok=True)

        embeddings = self.tokenizer.get_pretrained_embeddings(
            vocab_size, num_codebooks
        )
        embeddings = embeddings.cpu().numpy()
        np.save(save_path / save_name, embeddings)

    def _close_writer(self):
        """Close the current writer, if any, and forget it."""
        writer = getattr(self, "writer", None)
        self.writer = None
        if writer is not None:
            writer.close()

    def __del__(self):
        """
        Close the writer, if one is still open.
        """
        self._close_writer()


def _skip(save_path, save_name, conf):
    """
    Detects if the dataset extraction has been already done.
    If the extraction has been done, we can skip it.

    Arguments
    ---------
    save_path : str or pathlib.Path
        The path to the directory containing extracted tokens.
    save_name : str
        The base name of the saved tokens file.
    conf : dict
        Configuration to match against saved config.

    Returns
    -------
    bool
        if True, the preparation phase can be skipped.
        if False, it must be done.
    """
    save_path = pl.Path(save_path)

    # Both output files must be present
    for ext in (".ark", ".scp"):
        if not (save_path / f"{save_name}{ext}").exists():
            return False

    # The saved options must exist and match the requested configuration
    save_opt = save_path / OPT_FILE
    if not save_opt.exists():
        return False
    return load_pkl(save_opt.as_posix()) == conf
+ """ + if utt_id not in self.tokens: + raise KeyError(f"Utterance ID '{utt_id}' not found in tokens.") + tokens_path = self.tokens[utt_id] + tokens = kaldiio.load_mat(tokens_path) + tokens = torch.from_numpy(tokens).long() + + if num_codebooks is not None: + if isinstance(num_codebooks, int): + if num_codebooks <= 0: + raise ValueError( + f"Invalid num_codebooks value: {num_codebooks}. It must be a positive integer." + ) + if num_codebooks > tokens.size(-1): + raise ValueError( + f"Invalid number of codebooks: {num_codebooks}. " + f"Available codebooks: {tokens.size(-1)}." + ) + tokens = tokens[:, :num_codebooks] + elif isinstance(num_codebooks, list): + if not all( + isinstance(idx, int) and 0 <= idx < tokens.size(-1) + for idx in num_codebooks + ): + raise ValueError( + f"Invalid indices in num_codebooks list: {num_codebooks}. " + f"All indices must be integers within the range [0, {tokens.size(-1) - 1}]." + ) + tokens = tokens[:, num_codebooks] + else: + raise ValueError("num_codebooks must be an int or a list.") + + return tokens + + def _load(self, data_path, save_name): + """ + Loads the mapping from utterance IDs to token file paths. + + Arguments + --------- + data_path: str + The path to the data directory containing the token files. + save_name: str + The base name of the tokens files. + + Returns + ------- + utt2toks: dict + A dictionary mapping utterance IDs to their corresponding token file paths. + """ + scp_path = f"{data_path}/{save_name}.scp" + with open(scp_path, "r") as f: + utt2toks = { + line.strip().split(None, 1)[0]: line.strip().split(None, 1)[1] + for line in f + if line.strip() + } + return utt2toks + + def load_pretrained_embeddings(self, data_path, save_name="embeddings"): + """ + Loads pretrained embeddings from a specified path. + + Arguments + --------- + data_path : str + The directory where the embeddings are saved. + save_name : str, optional + The name of the embeddings file (default: "embeddings"). 
+ + Returns + ------- + embeddings : torch.Tensor + The loaded embeddings as a PyTorch tensor. + + Raises + ------ + FileNotFoundError + If the embeddings file does not exist at the specified path. + """ + data_path = pl.Path(data_path).absolute() + if not self.data_path.exists(): + raise ValueError(f"Data folder not found: {data_path.as_posix()}") + embeddings = np.load(data_path / f"{save_name}.npy") + embeddings = torch.from_numpy(embeddings) + return embeddings \ No newline at end of file diff --git a/speechbrain b/speechbrain index e602161f4..ecb34d8fa 160000 --- a/speechbrain +++ b/speechbrain @@ -1 +1 @@ -Subproject commit e602161f4d305e13a26fc71b7dbe4a4cfeaa8847 +Subproject commit ecb34d8fa4fb0aa8d9888e33478ec54064f25ff8