Commit: Add diarization scripts
wkevwang committed Oct 11, 2020
1 parent 30b90bd commit cfec2b8
Showing 5 changed files with 250 additions and 83 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,10 +5,12 @@ __pycache__/
.pytest_cache/
.vscode/

audio/
data/
snomed_ct/*.txt
transcriptions/
qa_summary/

.DS_Store
*.docx
*.json
175 changes: 114 additions & 61 deletions diarization.py
@@ -1,65 +1,118 @@
from resemblyzer import preprocess_wav, VoiceEncoder
"""
Diarize audio using Resemblyzer (https://github.com/resemble-ai/Resemblyzer)
This script is adapted from demo 2 in the Resemblyzer repo.
Specify the audio segments for the doctor and the patient as start and end
times joined by a dash, separating multiple segments with spaces,
e.g. --doctor_segments 3.2-5.5 10-12.1
Other arguments:
* partials_n_frames: The number of frames used per prediction
* audio_embed_rate: How many times per second to make predictions
* speaker_embed_rate: How many times per second to embed the audio
segments for each speaker
Every frame_step frames, a window of partials_n_frames frames is encoded
to make one prediction. Each prediction is most closely associated with
the centre of that window.
"""

import argparse
import json
import os
import numpy as np

import resemblyzer
from resemblyzer import preprocess_wav, VoiceEncoder, sampling_rate
from resemblyzer.hparams import mel_window_step


def secs_per_partial(args):
    """
    Duration of one partial (the partials_n_frames frames encoded
    for a single prediction), in seconds
    """
    return args.partials_n_frames * mel_window_step / 1000
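
As a quick sanity check on these units (a worked example using the script's defaults and resemblyzer's stock hyperparameters, where mel_window_step is 10 ms):

partials_n_frames = 160   # frames encoded per prediction (the script's default)
mel_window_step = 10      # ms of audio per mel frame (resemblyzer default)
audio_embed_rate = 16     # predictions requested per second (the script's default)

print(partials_n_frames * mel_window_step / 1000)  # 1.6 s of audio per prediction window
print(1 / audio_embed_rate)                        # 0.0625 s between requested predictions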


def frames_between_predictions(args):
    """
    # of mel frames between consecutive predictions (the frame step)
    """
    samples_per_frame = int(sampling_rate * mel_window_step / 1000)
    return int(np.round((sampling_rate / args.audio_embed_rate) / samples_per_frame))


def print_predictions(speaker_predictions, wav_splits, similarity_matrix, args, freq=2):
    """
    Print roughly <freq> predictions per second, each centred on its window
    """
    interval = max(1, int(args.audio_embed_rate / freq))  # guard against rates below freq
    for i in range(0, len(speaker_predictions), interval):
        midpoint_offset = (wav_splits[i].stop - wav_splits[i].start) / sampling_rate / 2
        seconds = (wav_splits[i].start / sampling_rate) + midpoint_offset
        print("{}m:{}s | Speaker: {} | Doctor Conf: {}".format(
            int(seconds / 60), round(seconds % 60, 1),
            speaker_predictions[i], similarity_matrix[0, i]))


def output_json(speaker_predictions, wav_splits, args):
diarization = []
for i in range(len(speaker_predictions)):
midpoint_offset = (wav_splits[i].stop - wav_splits[i].start) / sampling_rate / 2
seconds = (wav_splits[i].start / sampling_rate) + midpoint_offset
speaker_prediction = speaker_predictions[i]
diarization.append({
"time": round(seconds, 2),
"speaker": speaker_prediction
})

filename_prefix = os.path.splitext(os.path.basename(args.audio_file))[0]
json_filename = filename_prefix + '-diarization.json'
with open(json_filename, 'w') as f:
json.dump({"diarization": diarization}, f, indent=4)
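
For reference, the emitted JSON file has this shape (illustrative values only):

{
    "diarization": [
        {"time": 0.8, "speaker": "Doctor"},
        {"time": 0.86, "speaker": "Doctor"},
        {"time": 0.92, "speaker": "Patient"}
    ]
}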


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--audio_file", type=str, required=True)
parser.add_argument("--doctor_segments", type=str, nargs='+', required=True,)
parser.add_argument("--patient_segments", type=str, nargs='+', required=True)
parser.add_argument("--partials_n_frames", type=int, default=160, required=False)
parser.add_argument("--audio_embed_rate", type=int, default=16, required=False)
parser.add_argument("--speaker_embed_rate", type=int, default=16, required=False)
parser.add_argument("--model_file", type=str, default="pretrained.pt", required=False)
args = parser.parse_args()

# Set partials_n_frames (the number of frames encoded per prediction)
resemblyzer.voice_encoder.partials_n_frames = args.partials_n_frames
wav = preprocess_wav(args.audio_file)

# Cut some segments from each speaker as reference audio
num_doctor_segments = len(args.doctor_segments)
num_patient_segments = len(args.patient_segments)
doctor_segments = [[float(time.split('-')[0]), float(time.split('-')[1])]
for time in args.doctor_segments]
patient_segments = [[float(time.split('-')[0]), float(time.split('-')[1])]
for time in args.patient_segments]
segments = doctor_segments + patient_segments
speaker_names = (["Doctor"] * num_doctor_segments) + (["Patient"] * num_patient_segments)
speaker_wavs = [wav[int(s[0] * sampling_rate):int(s[1] * sampling_rate)] for s in segments]

# Encode the audio
print("Encode the audio...")
encoder = VoiceEncoder("cpu", model_file=args.model_file)
_, cont_embeds, wav_splits = encoder.embed_utterance(wav, return_partials=True, rate=args.audio_embed_rate)
speaker_embeds = [encoder.embed_utterance(speaker_wav, rate=args.speaker_embed_rate) for speaker_wav in speaker_wavs]

parser = argparse.ArgumentParser()
parser.add_argument("--audio_file", type=str, required=True)
parser.add_argument("--doctor_segments", type=str, nargs='+', required=True)
parser.add_argument("--patient_segments", type=str, nargs='+', required=True)
parser.add_argument("--partials_n_frames", default=160, type=int, required=False)
args = parser.parse_args()

time_per_partial_frame = args.partials_n_frames * 10 / 1000 * 16000
wav = preprocess_wav(args.audio_file)

num_doctor_segments = len(args.doctor_segments)
num_patient_segments = len(args.patient_segments)
doctor_segments = [[float(time.split('-')[0]), float(time.split('-')[1])]
for time in args.doctor_segments]
patient_segments = [[float(time.split('-')[0]), float(time.split('-')[1])]
for time in args.patient_segments]

# Cut some segments from single speakers as reference audio
segments = doctor_segments + patient_segments
speaker_names = (["Doctor"] * num_doctor_segments) + (["Patient"] * num_patient_segments)
speaker_wavs = [wav[int(s[0] * sampling_rate):int(s[1]) * sampling_rate] for s in segments]

rate = 16
encoder = VoiceEncoder("cpu")
print("Running the continuous embedding on cpu, this might take a while...")
_, cont_embeds, wav_splits = encoder.embed_utterance(wav, return_partials=True, rate=rate)

speaker_embeds = [encoder.embed_utterance(speaker_wav) for speaker_wav in speaker_wavs]
similarity_dict = {name: cont_embeds @ speaker_embed for name, speaker_embed in
zip(speaker_names, speaker_embeds)}
similarity_matrix = np.array([cont_embeds @ speaker_embed for name, speaker_embed in
zip(speaker_names, speaker_embeds)])

speaker_predictions_indexes = np.argmax(similarity_matrix, axis=0)
speaker_predictions = [speaker_names[index] for index in speaker_predictions_indexes]
times_per_sec = 2
interval = int(rate / times_per_sec)
for i in range(0, len(speaker_predictions), interval):
seconds = i * (960 / 16000)
seconds += time_per_partial_frame / 2 / 16000 # Centre prediction at middle of interval
print("{}m:{}s | Speaker: {} | Doctor Conf: {}".format(int(seconds / 60), round(seconds % 60, 1), speaker_predictions[i], similarity_matrix[0, i]))

output_json = {
"diarization": []
}
# Produce output JSON data
for i in range(len(speaker_predictions)):
seconds = i * (960 / 16000)
seconds += time_per_partial_frame / 2 / 16000 # Centre prediction at middle of interval
seconds = round(seconds, 2)
speaker_prediction = speaker_predictions[i]
output_json["diarization"].append({
"time": seconds,
"speaker": speaker_prediction
})

filename_prefix = os.path.splitext(os.path.basename(args.audio_file))[0]
json_filename = filename_prefix + '-diarization.json'
with open(json_filename, 'w') as f:
json.dump(output_json, f, indent=4)
# Determine who spoke when
similarity_matrix = np.array([cont_embeds @ speaker_embed
                              for speaker_embed in speaker_embeds])
speaker_predictions_indexes = np.argmax(similarity_matrix, axis=0)
speaker_predictions = [speaker_names[index] for index in speaker_predictions_indexes]
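
resemblyzer's embeddings are L2-normalised, so each cont_embeds @ speaker_embed product is a cosine similarity between one audio window and one reference speaker, and the argmax across speakers picks the closer reference. A minimal sketch of the same decision rule on dummy data:

import numpy as np

# Dummy stand-ins: 4 audio windows and 2 reference speakers, unit-normalised
rng = np.random.default_rng(0)
cont_embeds = rng.normal(size=(4, 256))
cont_embeds /= np.linalg.norm(cont_embeds, axis=1, keepdims=True)
speaker_embeds = rng.normal(size=(2, 256))
speaker_embeds /= np.linalg.norm(speaker_embeds, axis=1, keepdims=True)

similarity_matrix = np.array([cont_embeds @ e for e in speaker_embeds])  # shape (2, 4)
print(similarity_matrix.argmax(axis=0))  # most similar speaker index per window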

# Print predictions
print_predictions(speaker_predictions, wav_splits, similarity_matrix, args)

# Produce output JSON data
output_json(speaker_predictions, wav_splits, args)
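
A typical invocation might look like this (the audio path and segment times are hypothetical):

python diarization.py \
    --audio_file audio/visit01.wav \
    --doctor_segments 3.2-5.5 10-12.1 \
    --patient_segments 6.0-8.4 20.5-23.0 \
    --partials_n_frames 160 \
    --audio_embed_rate 16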
99 changes: 99 additions & 0 deletions diarization_tune.py
@@ -0,0 +1,99 @@
import argparse
import json
import os
import numpy as np
from copy import deepcopy

import resemblyzer
from resemblyzer import preprocess_wav, VoiceEncoder, sampling_rate
from rev_diarization import diarize_transcript_elements, parse_transcript_elements

parser = argparse.ArgumentParser()
parser.add_argument("--audio_file", type=str, required=True)
parser.add_argument("--doctor_segments", type=str, nargs='+', required=True)
parser.add_argument("--patient_segments", type=str, nargs='+', required=True)
parser.add_argument("--transcript_file", type=str, required=True)
parser.add_argument("--gt_diarization_file", type=str, required=True)
args = parser.parse_args()

with open(args.transcript_file, 'r') as f:
transcript_file = json.load(f)
transcript_elements = parse_transcript_elements(transcript_file)

with open(args.gt_diarization_file, 'r') as f:
gt_diarization = json.load(f)
gt_diarization = gt_diarization["diarization"]

wav = preprocess_wav(args.audio_file)

num_doctor_segments = len(args.doctor_segments)
num_patient_segments = len(args.patient_segments)
doctor_segments = [[float(time.split('-')[0]), float(time.split('-')[1])]
for time in args.doctor_segments]
patient_segments = [[float(time.split('-')[0]), float(time.split('-')[1])]
for time in args.patient_segments]

# Cut some segments from single speakers as reference audio
segments = doctor_segments + patient_segments
speaker_names = (["Doctor"] * num_doctor_segments) + (["Patient"] * num_patient_segments)
speaker_wavs = [wav[int(s[0] * sampling_rate):int(s[1] * sampling_rate)] for s in segments]

encoder = VoiceEncoder("cpu")

def compute_diarization(partials_n_frames, speaker_embed_rate, audio_embed_rate):
    # Set the number of frames encoded per prediction
    resemblyzer.voice_encoder.partials_n_frames = partials_n_frames
    _, cont_embeds, wav_splits = encoder.embed_utterance(wav, return_partials=True, rate=audio_embed_rate)
    speaker_embeds = [encoder.embed_utterance(speaker_wav, rate=speaker_embed_rate) for speaker_wav in speaker_wavs]
    similarity_matrix = np.array([cont_embeds @ speaker_embed
                                  for speaker_embed in speaker_embeds])
    speaker_predictions_indexes = np.argmax(similarity_matrix, axis=0)
    speaker_predictions = [speaker_names[index] for index in speaker_predictions_indexes]
    # Produce output JSON data, centring each prediction at the middle of its
    # window (a hardcoded 960-sample step is only valid for an embed rate of 16,
    # so derive the time from wav_splits instead)
    predictions = []
    for i in range(len(speaker_predictions)):
        seconds = (wav_splits[i].start + wav_splits[i].stop) / 2 / sampling_rate
        predictions.append({
            "time": round(seconds, 2),
            "speaker": speaker_predictions[i]
        })
    return predictions


def diarization_word_accuracy(pred_diarization, gt_diarization, transcript_elements):
gt_diarization = gt_diarization[:len(pred_diarization)]
num_correct = 0
pred_transcript_elements = diarize_transcript_elements(deepcopy(transcript_elements), pred_diarization)
gt_transcript_elements = diarize_transcript_elements(deepcopy(transcript_elements), gt_diarization)
for i in range(len(pred_transcript_elements)):
if "ts" in pred_transcript_elements[i]:
assert pred_transcript_elements[i]["ts"] == gt_transcript_elements[i]["ts"]
if pred_transcript_elements[i]["speaker"] == gt_transcript_elements[i]["speaker"]:
num_correct += 1
return num_correct / len(transcript_elements)
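
In miniature, the metric reduces to a word-level agreement rate (made-up elements, assuming the speaker fields have already been filled in by diarize_transcript_elements):

pred = [{"speaker": "Doctor"}, {"speaker": "Patient"}, {"speaker": "Patient"}]
gt = [{"speaker": "Doctor"}, {"speaker": "Doctor"}, {"speaker": "Patient"}]
print(sum(p["speaker"] == g["speaker"] for p, g in zip(pred, gt)) / len(gt))  # ~0.67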


audio_embed_rates = [8, 12, 16]
speaker_embed_rates = [8, 12, 16]
partials_n_frames_list = [115, 120, 125]

max_accuracy = 0

for audio_embed_rate in audio_embed_rates:
for speaker_embed_rate in speaker_embed_rates:
for partials_n_frames in partials_n_frames_list:
pred_diarization = compute_diarization(partials_n_frames, speaker_embed_rate, audio_embed_rate)
accuracy = diarization_word_accuracy(pred_diarization, gt_diarization, transcript_elements)
print("Accuracy: {}% | AER: {}, SER: {}, PNF: {}".format(
round(accuracy * 100, 2),
audio_embed_rate,
speaker_embed_rate,
partials_n_frames))
if accuracy > max_accuracy:
print("^ New max accuracy!")
max_accuracy = accuracy
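
If the sweep should also report which configuration won, a small extension (a sketch, reusing the functions defined above):

from itertools import product

best_accuracy, best_params = 0.0, None
for aer, ser, pnf in product(audio_embed_rates, speaker_embed_rates, partials_n_frames_list):
    accuracy = diarization_word_accuracy(
        compute_diarization(pnf, ser, aer), gt_diarization, transcript_elements)
    if accuracy > best_accuracy:
        best_accuracy, best_params = accuracy, (aer, ser, pnf)
print(best_accuracy, best_params)  # e.g. (audio_embed_rate, speaker_embed_rate, partials_n_frames)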
6 changes: 3 additions & 3 deletions resemblyzer/voice_encoder.py
@@ -9,7 +9,7 @@


class VoiceEncoder(nn.Module):
def __init__(self, device: Union[str, torch.device]=None, verbose=True):
def __init__(self, device: Union[str, torch.device]=None, verbose=True, model_file="pretrained.pt"):
"""
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
If None, defaults to cuda if it is available on your machine, otherwise the model will
@@ -30,7 +30,7 @@ def __init__(self, device: Union[str, torch.device]=None, verbose=True):
self.device = device

# Load the pretrained model's weights
weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
weights_fpath = Path(__file__).resolve().parent.joinpath(model_file)
if not weights_fpath.exists():
raise Exception("Couldn't find the voice encoder pretrained model at %s." %
weights_fpath)
@@ -92,7 +92,7 @@ def compute_partial_slices(n_samples: int, rate, min_coverage):
assert 0 < frame_step, "The rate is too high"
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
(sampling_rate / (samples_per_frame * partials_n_frames))

# Compute the slices
wav_slices, mel_slices = [], []
steps = max(1, n_frames - partials_n_frames + frame_step + 1)
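A quick check of the bound asserted above, using resemblyzer's defaults (sampling_rate 16000, samples_per_frame 160, partials_n_frames 160):

sampling_rate, samples_per_frame, partials_n_frames = 16000, 160, 160
print(sampling_rate / (samples_per_frame * partials_n_frames))  # 0.625: the minimum workable rate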
51 changes: 32 additions & 19 deletions rev_diarization.py
@@ -17,22 +17,7 @@ def all_dict_values_same(d):
return (len(set(d.values())) == 1)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--transcript_file", type=str, required=True)
parser.add_argument("--diarization_file", type=str, required=True)
parser.add_argument("--output_folder", type=str, required=True)
parser.add_argument("--diarization_offset", required=False, type=float, default=0.0)
args = parser.parse_args()

with open(args.transcript_file, 'r') as f:
transcript_file = json.load(f)

with open(args.diarization_file, 'r') as f:
diarization_file = json.load(f)

transcript = []

def parse_transcript_elements(transcript_file):
transcript_monologues = transcript_file["monologues"]
transcript_elements_lists = [m["elements"] for m in transcript_monologues]
transcript_elements = []
@@ -41,7 +26,10 @@ def all_dict_values_same(d):
if element['value'] == ' ':
continue
transcript_elements.append(element)

return transcript_elements
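
For context, the transcript JSON this parses is expected to look roughly like the following (a hypothetical fragment; field names inferred from how the code reads them):

transcript_file = {
    "monologues": [
        {"elements": [
            {"type": "text", "value": "Hello", "ts": 0.5, "end_ts": 0.9},
            {"value": " "},   # bare spaces are skipped by the parser
            {"type": "text", "value": "there", "ts": 0.9, "end_ts": 1.3}
        ]}
    ]
}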


def diarize_transcript_elements(transcript_elements, diarization, diarization_offset=0.0):
# Add speaker info to elements
last_speaker = 'Doctor'
for idx, element in enumerate(transcript_elements):
@@ -52,8 +40,8 @@ def all_dict_values_same(d):
start_time = float(element["ts"])
end_time = float(element["end_ts"])
speaker_counts = {}
for prediction in diarization_file["diarization"]:
prediction_time = prediction["time"] + args.diarization_offset
for prediction in diarization:
prediction_time = prediction["time"] + diarization_offset
if prediction_time > end_time:
break
if start_time <= prediction_time <= end_time:
@@ -70,6 +58,31 @@ def all_dict_values_same(d):
speaker = max(speaker_counts, key=speaker_counts.get)
element["speaker"] = speaker
last_speaker = speaker
return transcript_elements
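
The assignment itself is a majority vote over the frame-level predictions falling inside each word's time span; in miniature (hypothetical values):

start_time, end_time = 2.0, 2.6
diarization = [{"time": 2.1, "speaker": "Doctor"},
               {"time": 2.3, "speaker": "Doctor"},
               {"time": 2.5, "speaker": "Patient"}]
speaker_counts = {}
for prediction in diarization:
    if start_time <= prediction["time"] <= end_time:
        speaker_counts[prediction["speaker"]] = speaker_counts.get(prediction["speaker"], 0) + 1
print(max(speaker_counts, key=speaker_counts.get))  # Doctor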


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--transcript_file", type=str, required=True)
parser.add_argument("--diarization_file", type=str, required=True)
parser.add_argument("--output_folder", type=str, required=True)
parser.add_argument("--diarization_offset", required=False, type=float, default=0.0)
args = parser.parse_args()

with open(args.transcript_file, 'r') as f:
transcript_file = json.load(f)

with open(args.diarization_file, 'r') as f:
diarization_file = json.load(f)

transcript = []

transcript_elements = parse_transcript_elements(transcript_file)

diarize_transcript_elements(
transcript_elements,
diarization_file["diarization"],
args.diarization_offset)

# Group speaker elements into a transcript
last_speaker = None
