diff --git a/examples/mms/data_prep/align_and_segment.py b/examples/mms/data_prep/align_and_segment.py index dcaba4f5a2..cd5045eabc 100644 --- a/examples/mms/data_prep/align_and_segment.py +++ b/examples/mms/data_prep/align_and_segment.py @@ -85,9 +85,13 @@ def get_alignments( token_indices = [] blank = dictionary[""] + + targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE) + input_lengths = torch.tensor(emissions.shape[0]) + target_lengths = torch.tensor(targets.shape[0]) - path, _ = F.force_align( - emissions, torch.Tensor(token_indices, device=DEVICE).int(), blank=blank + path, _ = F.forced_align( + emissions, targets, input_lengths, target_lengths, blank=blank ) path = path.to("cpu").tolist() segments = merge_repeats(path, {v: k for k, v in dictionary.items()})