diff --git a/examples/mms/data_prep/align_and_segment.py b/examples/mms/data_prep/align_and_segment.py
index 732f93c3fb..dcaba4f5a2 100644
--- a/examples/mms/data_prep/align_and_segment.py
+++ b/examples/mms/data_prep/align_and_segment.py
@@ -51,12 +51,12 @@ def generate_emissions(model, audio_file):
             offset = time_to_frame(input_start_time)
 
             emissions_ = emissions_[
-                :, emission_start_frame - offset : emission_end_frame - offset
+                emission_start_frame - offset : emission_end_frame - offset, :
             ]
             emissions_arr.append(emissions_)
             i += EMISSION_INTERVAL
 
-    emissions = torch.cat(emissions_arr, dim=1).squeeze()
+    emissions = torch.cat(emissions_arr, dim=0).squeeze()
     emissions = torch.log_softmax(emissions, dim=-1)
 
     stride = float(waveform.size(1) * 1000 / emissions.size(0) / SAMPLING_FREQ)