-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
250 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,118 @@ | ||
from resemblyzer import preprocess_wav, VoiceEncoder | ||
""" | ||
Diarize audio using Resemblyzer (https://github.com/resemble-ai/Resemblyzer) | ||
This script is adapted from demo 2 in the Resemblyzer repo. | ||
Specify the audio segments for doctor and patient with start and end time, | ||
connected with a dash. Add space between multiple audio segments. | ||
E.g. --doctor_segments 3.2-5.5 10-12.1 | ||
Other arguments | ||
* partials_n_frames: The number of frames used per prediction | ||
* audio_embed_rate: How many times per second to make predictions | ||
* speaker_embed_rate: How many times per second to embed the audio | ||
segments for each speaker | ||
Every <frame_step> # of frames, partials_n_frames # of frames are | ||
encoded to make a prediction. The prediction seems to be most associated | ||
with the centre of that time interval. | ||
""" | ||
|
||
import argparse | ||
import json | ||
import os | ||
import numpy as np | ||
|
||
import resemblyzer | ||
from resemblyzer import preprocess_wav, VoiceEncoder, sampling_rate | ||
from resemblyzer.hparams import mel_window_step | ||
|
||
|
||
def secs_per_partial(args): | ||
""" | ||
Samples per partial (n frames encoded for a single prediction) | ||
in seconds | ||
""" | ||
return args.partials_n_frames * mel_window_step / 1000 | ||
|
||
|
||
def samples_between_preidctions(args): | ||
""" | ||
# of frames per partial | ||
""" | ||
return int(np.round((sampling_rate / rate) / samples_per_frame)) | ||
|
||
|
||
def print_predictions(speaker_predictions, wav_splits, args, freq=2): | ||
interval = int(args.audio_embed_rate / freq) | ||
for i in range(0, len(speaker_predictions), interval): | ||
midpoint_offset = (wav_splits[i].stop - wav_splits[i].start) / sampling_rate / 2 | ||
seconds = (wav_splits[i].start / sampling_rate) + midpoint_offset | ||
print("{}m:{}s | Speaker: {} | Doctor Conf: {}".format( | ||
int(seconds / 60), round(seconds % 60, 1), | ||
speaker_predictions[i], similarity_matrix[0, i])) | ||
|
||
|
||
def output_json(speaker_predictions, wav_splits, args): | ||
diarization = [] | ||
for i in range(len(speaker_predictions)): | ||
midpoint_offset = (wav_splits[i].stop - wav_splits[i].start) / sampling_rate / 2 | ||
seconds = (wav_splits[i].start / sampling_rate) + midpoint_offset | ||
speaker_prediction = speaker_predictions[i] | ||
diarization.append({ | ||
"time": round(seconds, 2), | ||
"speaker": speaker_prediction | ||
}) | ||
|
||
filename_prefix = os.path.splitext(os.path.basename(args.audio_file))[0] | ||
json_filename = filename_prefix + '-diarization.json' | ||
with open(json_filename, 'w') as f: | ||
json.dump({"diarization": diarization}, f, indent=4) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--audio_file", type=str, required=True) | ||
parser.add_argument("--doctor_segments", type=str, nargs='+', required=True,) | ||
parser.add_argument("--patient_segments", type=str, nargs='+', required=True) | ||
parser.add_argument("--partials_n_frames", type=int, default=160, required=False) | ||
parser.add_argument("--audio_embed_rate", type=int, default=16, required=False) | ||
parser.add_argument("--speaker_embed_rate", type=int, default=16, required=False) | ||
parser.add_argument("--model_file", type=str, default="pretrained.pt", required=False) | ||
args = parser.parse_args() | ||
|
||
# Set partials_n_frames (number of frames region per prediction) | ||
resemblyzer.voice_encoder.partials_n_frames = args.partials_n_frames | ||
wav = preprocess_wav(args.audio_file) | ||
|
||
# Cut some segments from each speaker as reference audio | ||
num_doctor_segments = len(args.doctor_segments) | ||
num_patient_segments = len(args.patient_segments) | ||
doctor_segments = [[float(time.split('-')[0]), float(time.split('-')[1])] | ||
for time in args.doctor_segments] | ||
patient_segments = [[float(time.split('-')[0]), float(time.split('-')[1])] | ||
for time in args.patient_segments] | ||
segments = doctor_segments + patient_segments | ||
speaker_names = (["Doctor"] * num_doctor_segments) + (["Patient"] * num_patient_segments) | ||
speaker_wavs = [wav[int(s[0] * sampling_rate):int(s[1]) * sampling_rate] for s in segments] | ||
|
||
# Encode the audio | ||
print("Encode the audio...") | ||
encoder = VoiceEncoder("cpu", model_file=args.model_file) | ||
_, cont_embeds, wav_splits = encoder.embed_utterance(wav, return_partials=True, rate=args.audio_embed_rate) | ||
speaker_embeds = [encoder.embed_utterance(speaker_wav, rate=args.speaker_embed_rate) for speaker_wav in speaker_wavs] | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--audio_file", type=str, required=True) | ||
parser.add_argument("--doctor_segments", type=str, nargs='+', required=True) | ||
parser.add_argument("--patient_segments", type=str, nargs='+', required=True) | ||
parser.add_argument("--partials_n_frames", default=160, type=int, required=False) | ||
args = parser.parse_args() | ||
|
||
time_per_partial_frame = args.partials_n_frames * 10 / 1000 * 16000 | ||
wav = preprocess_wav(args.audio_file) | ||
|
||
num_doctor_segments = len(args.doctor_segments) | ||
num_patient_segments = len(args.patient_segments) | ||
doctor_segments = [[float(time.split('-')[0]), float(time.split('-')[1])] | ||
for time in args.doctor_segments] | ||
patient_segments = [[float(time.split('-')[0]), float(time.split('-')[1])] | ||
for time in args.patient_segments] | ||
|
||
# Cut some segments from single speakers as reference audio | ||
segments = doctor_segments + patient_segments | ||
speaker_names = (["Doctor"] * num_doctor_segments) + (["Patient"] * num_patient_segments) | ||
speaker_wavs = [wav[int(s[0] * sampling_rate):int(s[1]) * sampling_rate] for s in segments] | ||
|
||
rate = 16 | ||
encoder = VoiceEncoder("cpu") | ||
print("Running the continuous embedding on cpu, this might take a while...") | ||
_, cont_embeds, wav_splits = encoder.embed_utterance(wav, return_partials=True, rate=rate) | ||
|
||
speaker_embeds = [encoder.embed_utterance(speaker_wav) for speaker_wav in speaker_wavs] | ||
similarity_dict = {name: cont_embeds @ speaker_embed for name, speaker_embed in | ||
zip(speaker_names, speaker_embeds)} | ||
similarity_matrix = np.array([cont_embeds @ speaker_embed for name, speaker_embed in | ||
zip(speaker_names, speaker_embeds)]) | ||
|
||
speaker_predictions_indexes = np.argmax(similarity_matrix, axis=0) | ||
speaker_predictions = [speaker_names[index] for index in speaker_predictions_indexes] | ||
times_per_sec = 2 | ||
interval = int(rate / times_per_sec) | ||
for i in range(0, len(speaker_predictions), interval): | ||
seconds = i * (960 / 16000) | ||
seconds += time_per_partial_frame / 2 / 16000 # Centre prediction at middle of interval | ||
print("{}m:{}s | Speaker: {} | Doctor Conf: {}".format(int(seconds / 60), round(seconds % 60, 1), speaker_predictions[i], similarity_matrix[0, i])) | ||
|
||
output_json = { | ||
"diarization": [] | ||
} | ||
# Produce output JSON data | ||
for i in range(len(speaker_predictions)): | ||
seconds = i * (960 / 16000) | ||
seconds += time_per_partial_frame / 2 / 16000 # Centre prediction at middle of interval | ||
seconds = round(seconds, 2) | ||
speaker_prediction = speaker_predictions[i] | ||
output_json["diarization"].append({ | ||
"time": seconds, | ||
"speaker": speaker_prediction | ||
}) | ||
|
||
filename_prefix = os.path.splitext(os.path.basename(args.audio_file))[0] | ||
json_filename = filename_prefix + '-diarization.json' | ||
with open(json_filename, 'w') as f: | ||
json.dump(output_json, f, indent=4) | ||
# Determine who spoke when | ||
similarity_dict = {name: cont_embeds @ speaker_embed for name, speaker_embed in | ||
zip(speaker_names, speaker_embeds)} | ||
similarity_matrix = np.array([cont_embeds @ speaker_embed for name, speaker_embed in | ||
zip(speaker_names, speaker_embeds)]) | ||
speaker_predictions_indexes = np.argmax(similarity_matrix, axis=0) | ||
speaker_predictions = [speaker_names[index] for index in speaker_predictions_indexes] | ||
|
||
# Print predictions | ||
print_predictions(speaker_predictions, wav_splits, args) | ||
|
||
# Produce output JSON data | ||
output_json(speaker_predictions, wav_splits, args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import argparse | ||
import json | ||
import os | ||
import numpy as np | ||
from copy import deepcopy | ||
|
||
import resemblyzer | ||
from resemblyzer import preprocess_wav, VoiceEncoder, sampling_rate | ||
from rev_diarization import diarize_transcript_elements, parse_transcript_elements | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--audio_file", type=str, required=True) | ||
parser.add_argument("--doctor_segments", type=str, nargs='+', required=True) | ||
parser.add_argument("--patient_segments", type=str, nargs='+', required=True) | ||
parser.add_argument("--transcript_file", type=str, required=True) | ||
parser.add_argument("--gt_diarization_file", type=str, required=True) | ||
args = parser.parse_args() | ||
|
||
with open(args.transcript_file, 'r') as f: | ||
transcript_file = json.load(f) | ||
transcript_elements = parse_transcript_elements(transcript_file) | ||
|
||
with open(args.gt_diarization_file, 'r') as f: | ||
gt_diarization = json.load(f) | ||
gt_diarization = gt_diarization["diarization"] | ||
|
||
wav = preprocess_wav(args.audio_file) | ||
|
||
num_doctor_segments = len(args.doctor_segments) | ||
num_patient_segments = len(args.patient_segments) | ||
doctor_segments = [[float(time.split('-')[0]), float(time.split('-')[1])] | ||
for time in args.doctor_segments] | ||
patient_segments = [[float(time.split('-')[0]), float(time.split('-')[1])] | ||
for time in args.patient_segments] | ||
|
||
# Cut some segments from single speakers as reference audio | ||
segments = doctor_segments + patient_segments | ||
speaker_names = (["Doctor"] * num_doctor_segments) + (["Patient"] * num_patient_segments) | ||
speaker_wavs = [wav[int(s[0] * sampling_rate):int(s[1]) * sampling_rate] for s in segments] | ||
|
||
encoder = VoiceEncoder("cpu") | ||
|
||
def compute_diarization(partials_n_frames, speaker_embed_rate, audio_embed_rate): | ||
time_per_partial_frame = partials_n_frames * 10 / 1000 * 16000 | ||
resemblyzer.voice_encoder.partials_n_frames = partials_n_frames | ||
_, cont_embeds, wav_splits = encoder.embed_utterance(wav, return_partials=True, rate=audio_embed_rate) | ||
speaker_embeds = [encoder.embed_utterance(speaker_wav, rate=speaker_embed_rate) for speaker_wav in speaker_wavs] | ||
similarity_dict = {name: cont_embeds @ speaker_embed for name, speaker_embed in | ||
zip(speaker_names, speaker_embeds)} | ||
similarity_matrix = np.array([cont_embeds @ speaker_embed for name, speaker_embed in | ||
zip(speaker_names, speaker_embeds)]) | ||
speaker_predictions_indexes = np.argmax(similarity_matrix, axis=0) | ||
speaker_predictions = [speaker_names[index] for index in speaker_predictions_indexes] | ||
predictions = [] | ||
# Produce output JSON data | ||
for i in range(len(speaker_predictions)): | ||
seconds = i * (960 / 16000) | ||
seconds += time_per_partial_frame / 2 / 16000 # Centre prediction at middle of interval | ||
seconds = round(seconds, 2) | ||
speaker_prediction = speaker_predictions[i] | ||
predictions.append({ | ||
"time": seconds, | ||
"speaker": speaker_prediction | ||
}) | ||
return predictions | ||
|
||
|
||
def diarization_word_accuracy(pred_diarization, gt_diarization, transcript_elements): | ||
gt_diarization = gt_diarization[:len(pred_diarization)] | ||
num_correct = 0 | ||
pred_transcript_elements = diarize_transcript_elements(deepcopy(transcript_elements), pred_diarization) | ||
gt_transcript_elements = diarize_transcript_elements(deepcopy(transcript_elements), gt_diarization) | ||
for i in range(len(pred_transcript_elements)): | ||
if "ts" in pred_transcript_elements[i]: | ||
assert pred_transcript_elements[i]["ts"] == gt_transcript_elements[i]["ts"] | ||
if pred_transcript_elements[i]["speaker"] == gt_transcript_elements[i]["speaker"]: | ||
num_correct += 1 | ||
return num_correct / len(transcript_elements) | ||
|
||
|
||
audio_embed_rates = [8, 12, 16] | ||
speaker_embed_rates = [8, 12, 16] | ||
partials_n_frames_list = [115, 120, 125] | ||
|
||
max_accuracy = 0 | ||
|
||
for audio_embed_rate in audio_embed_rates: | ||
for speaker_embed_rate in speaker_embed_rates: | ||
for partials_n_frames in partials_n_frames_list: | ||
pred_diarization = compute_diarization(partials_n_frames, speaker_embed_rate, audio_embed_rate) | ||
accuracy = diarization_word_accuracy(pred_diarization, gt_diarization, transcript_elements) | ||
print("Accuracy: {}% | AER: {}, SER: {}, PNF: {}".format( | ||
round(accuracy * 100, 2), | ||
audio_embed_rate, | ||
speaker_embed_rate, | ||
partials_n_frames)) | ||
if accuracy > max_accuracy: | ||
print("^ New max accuracy!") | ||
max_accuracy = accuracy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters