I first posted this issue on faster-whisper, but they suggested I post it here since it is likely an issue with CTranslate2.
Description
When using custom fine-tuned models, faster-whisper shows significant WER degradation compared to OpenAI's reference implementation (13.5% vs. 8.2% WER). Through investigation, I traced this to numerical differences that start at the very first Conv1D operation in the encoder; the differences are also present with the original large-v3 weights. In both implementations the weights are stored in float16, which suggests a numerical or algorithmic issue.
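Identical float16 weights do not by themselves guarantee identical outputs. A convolution sums in_channels × kernel_size products per output value (128 × 3 = 384 for conv1 of large-v3), and different kernels may accumulate those products in different precisions or orders. The toy sketch below, built around a hypothetical `sequential_sum` helper, only illustrates how large that effect can be; it is not a claim about what CTranslate2 does internally.

```python
import numpy as np


def sequential_sum(values, dtype):
    """Add values one by one, rounding the running total to `dtype` after each step."""
    acc = dtype(0)
    for v in values:
        acc = dtype(acc + dtype(v))
    return acc


# Deterministic case: float16 can only represent even integers in [2048, 4096),
# so 2048 + 1 = 2049 rounds back to 2048 (ties to even) and the sum stalls.
ones = np.ones(4096)
print(sequential_sum(ones, np.float16))  # 2048.0
print(sequential_sum(ones, np.float32))  # 4096.0

# A dot product with the same reduction length as large-v3's conv1
# (128 input channels x kernel size 3 = 384 products) on float16 inputs.
rng = np.random.default_rng(0)
a = rng.standard_normal(384).astype(np.float16)
b = rng.standard_normal(384).astype(np.float16)
products = a.astype(np.float64) * b.astype(np.float64)  # exact in float64
reference = products.sum()
for dtype in (np.float16, np.float32):
    err = abs(float(sequential_sum(products, dtype)) - reference)
    print(f"{dtype.__name__} accumulator: abs error {err:.2e}")
```

With a float32 accumulator the error is typically orders of magnitude smaller than with a float16 one, so the accumulator type can matter as much as the storage type of the weights.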
Here are some benchmark results on a custom dataset (lower WER is better):
| Implementation | Model | Performance (WER) |
|---|---|---|
| openai | large-v3 | 0.124165 |
| faster-whisper | large-v3 | 0.0992092 |
| openai | custom | 0.0819845 |
| faster-whisper | custom | 0.135567 |
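Expressed as relative changes, the benchmark numbers show that faster-whisper scores about 20% lower WER than openai on the stock large-v3 weights in this benchmark, but about 65% higher WER on the custom model. A short calculation, using a hypothetical `relative_change` helper and only the values from the benchmark table:

```python
# WER values copied from the benchmark table (lower is better).
wer = {
    ("openai", "large-v3"): 0.124165,
    ("faster-whisper", "large-v3"): 0.0992092,
    ("openai", "custom"): 0.0819845,
    ("faster-whisper", "custom"): 0.135567,
}


def relative_change(new, baseline):
    """Relative change of `new` with respect to `baseline` (positive means worse WER)."""
    return (new - baseline) / baseline


for model in ("large-v3", "custom"):
    change = relative_change(wer[("faster-whisper", model)], wer[("openai", model)])
    print(f"{model}: faster-whisper vs openai WER {change:+.1%}")
# large-v3: faster-whisper vs openai WER -20.1%
# custom: faster-whisper vs openai WER +65.4%
```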
After comparing the logits of the two implementations and narrowing down the root cause, I found that the difference starts as early as the first conv1d operation in the encoder.
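When narrowing a discrepancy down layer by layer, it helps to summarize each comparison with the same few numbers. A minimal sketch of such a summary; `compare_outputs` is a hypothetical helper, and the `eps` guard for near-zero reference values is an assumption you may want to tune:

```python
import numpy as np


def compare_outputs(a, b, eps=1e-6):
    """Summarize elementwise differences between two arrays of the same shape.

    Both inputs are upcast to float64 so the comparison itself adds no rounding.
    `b` is treated as the reference for the relative error.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return {
        "max_abs": float(diff.max()),
        "mean_abs": float(diff.mean()),
        # eps avoids division by zero where the reference is exactly 0
        "max_rel": float((diff / (np.abs(b) + eps)).max()),
        # index of the worst element, useful for spotting edge effects (e.g. padding)
        "argmax_abs": np.unravel_index(int(diff.argmax()), diff.shape),
    }
```

For example, `compare_outputs(ct2_out, np.load("enc_numpy_torch.npy"))`, where `ct2_out` is a hypothetical NumPy copy of the CTranslate2 output; if the two arrays use different layouts, transpose one of them first.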
Steps to reproduce
Modify the Whisper encoder's `operator()` so that it applies only the first conv1d operation, i.e.:
```cpp
void WhisperEncoder::operator()(const StorageView& features, StorageView& output) {
  PROFILE("WhisperEncoder");

  if (features.rank() != 3)
    throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
                                + std::to_string(features.rank())
                                + " dimension(s) instead");
  if (features.dim(1) != input_size() || features.dim(2) > max_input_time())
    throw std::invalid_argument("Invalid input features shape: expected an input with shape ("
                                + std::to_string(features.dim(0))
                                + ", "
                                + std::to_string(input_size())
                                + ", "
                                + std::to_string(std::min(features.dim(2), max_input_time()))
                                + "), but got an input with shape ("
                                + std::to_string(features.dim(0))
                                + ", "
                                + std::to_string(features.dim(1))
                                + ", "
                                + std::to_string(features.dim(2))
                                + ") instead");

  // Debugging change: write the raw output of the first convolution to `output`
  // and return, skipping the GELU, the second convolution, and the transformer layers.
  _conv1(features, output);
}
```
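Comparing two float16 (or bfloat16) implementations only shows that they disagree, not which one is closer to the exact result. One option is a float64 reference convolution computed from the same checkpoint weights (for example `weights['model_state_dict']['encoder.conv1.weight'].float().numpy()` from the checkpoint loaded in the PyTorch script in this thread). The sketch below assumes the semantics of `torch.nn.functional.conv1d` with groups=1 and dilation=1; `conv1d_reference` is a hypothetical helper, not part of either library.

```python
import numpy as np


def conv1d_reference(x, weight, bias, stride=1, padding=1):
    """Float64 Conv1d matching torch.nn.functional.conv1d for groups=1, dilation=1.

    x:      (batch, in_channels, time)
    weight: (out_channels, in_channels, kernel_size)
    bias:   (out_channels,)
    returns (batch, out_channels, out_time)
    """
    x = np.asarray(x, dtype=np.float64)
    w = np.asarray(weight, dtype=np.float64)
    b = np.asarray(bias, dtype=np.float64)
    batch, in_ch, time = x.shape
    out_ch, w_in_ch, k = w.shape
    if in_ch != w_in_ch:
        raise ValueError("input channels do not match the weight tensor")

    # Zero-pad the time axis on both sides, as `padding` does in PyTorch.
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding)))
    out_time = (time + 2 * padding - k) // stride + 1
    out = np.zeros((batch, out_ch, out_time), dtype=np.float64)
    for j in range(k):
        # Input samples seen by kernel tap j at every output position.
        xs = xp[:, :, j : j + stride * (out_time - 1) + 1 : stride]
        out += np.einsum("bit,oi->bot", xs, w[:, :, j])
    return out + b[None, :, None]


# Tiny sanity check: all-ones input and kernel with padding=1.
x = np.ones((1, 1, 3))
w = np.ones((1, 1, 3))
print(conv1d_reference(x, w, np.zeros(1)))  # [[[2. 3. 2.]]]
```

Feeding the float16 input and weights (upcast to float64) through this reference and measuring each implementation's distance to it shows which side drifts further.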
Running the following scripts then shows the difference between the two implementations.
I am also fighting accuracy problems after conversion and have been running some experiments as well, but I get exactly the same numbers for the convolutions (I also modified my CTranslate2 Whisper implementation and am adding back specific operations one at a time). I believe you should explicitly specify the convolution parameters in your torch experiments:
```python
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

# All-ones dummy input: (batch, n_mels, frames)
audio = torch.ones((1, 128, 3000), dtype=torch.bfloat16, device='cuda:0', requires_grad=False)
weights = torch.load('/home/gpu/.cache/whisper/large-v3.pt', map_location='cpu')
weight1 = weights['model_state_dict']['encoder.conv1.weight'].to(torch.bfloat16).to('cuda:0')
bias1 = weights['model_state_dict']['encoder.conv1.bias'].to(torch.bfloat16).to('cuda:0')
weight2 = weights['model_state_dict']['encoder.conv2.weight'].to(torch.bfloat16).to('cuda:0')
bias2 = weights['model_state_dict']['encoder.conv2.bias'].to(torch.bfloat16).to('cuda:0')

# Functional equivalent (note that conv2 uses stride=2):
# out = F.conv1d(audio, weight1, bias1, padding=1)
# out = F.gelu(out)
# out = F.conv1d(out, weight2, bias2, stride=2, padding=1)

conv1 = nn.Conv1d(
    in_channels=128,
    out_channels=1280,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True
).to(audio.device, dtype=torch.bfloat16)
conv2 = nn.Conv1d(
    in_channels=1280,
    out_channels=1280,
    kernel_size=3,
    stride=2,
    padding=1,
    bias=True
).to(audio.device, dtype=torch.bfloat16)

# Copy the checkpoint weights and biases into the conv layers
with torch.no_grad():
    conv1.weight.copy_(weight1)
    conv1.bias.copy_(bias1)
    conv2.weight.copy_(weight2)
    conv2.bias.copy_(bias2)

# Forward pass: conv1 -> GELU -> conv2 -> GELU
x = conv1(audio)
x = F.gelu(x)
y = x.detach().cpu().to(torch.float32).numpy()
# Last 5 time steps of channel 0 after conv1 + GELU
for i in range(5):
    print(y[0, 0, -5 + i])
print("Output shape after conv1:", x.shape)
x = conv2(x)
x = F.gelu(x)
print("Output shape after conv1+conv2:", x.shape)
enc_numpy_2 = x.detach().cpu().to(torch.float32).numpy()
#enc_numpy_2 = np.transpose(enc_numpy_2, (0, 2, 1))
print("Torch: ", enc_numpy_2.shape)
# Last 5 time steps of channel 0 after conv2 + GELU
for i in range(5):
    print(enc_numpy_2[0, 0, -5 + i])
# Save enc_numpy_2 to a file for comparison
np.save('enc_numpy_torch.npy', enc_numpy_2)
```
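The script above casts everything to bfloat16, whereas the issue compares float16 weights. The two 16-bit formats round very differently: bfloat16 keeps 8 significand bits (unit roundoff 2**-8) while float16 keeps 11 (unit roundoff 2**-11), so bfloat16 results should not be judged against float16 tolerances. NumPy has no bfloat16 dtype, so this sketch emulates it by rounding float32 bit patterns to their upper 16 bits with round-to-nearest-even; `to_bfloat16` is a hypothetical helper and assumes finite inputs.

```python
import numpy as np


def to_bfloat16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32.

    Assumes finite inputs; NaN payloads are not handled specially.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Adding 0x7FFF plus the lowest kept bit makes exact ties round to an even significand.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    bits = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return np.asarray(bits).view(np.float32)


rng = np.random.default_rng(0)
# Positive values well inside float16's normal range (no subnormals or overflow).
x = rng.uniform(0.1, 100.0, size=10_000).astype(np.float32)
x64 = x.astype(np.float64)
rel_fp16 = np.max(np.abs(x.astype(np.float16).astype(np.float64) - x64) / x64)
rel_bf16 = np.max(np.abs(to_bfloat16(x).astype(np.float64) - x64) / x64)
print(f"float16  max relative rounding error: {rel_fp16:.2e} (bound {2**-11:.2e})")
print(f"bfloat16 max relative rounding error: {rel_bf16:.2e} (bound {2**-8:.2e})")
```

Rounding the inputs and weights alone therefore already differs by up to roughly 8x between the two formats, before any accumulation effects.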
faster-whisper script:
PyTorch equivalent:
Running these two scripts gives the following outputs:
faster-whisper:
PyTorch:
Environment
- Python 3.10.10
- CUDA version: 12.3
- Everything runs inside a Docker container: `nvcr.io/nvidia/pytorch:23.10-py3`
- GPU: RTX 3090
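The environment list does not include the CTranslate2 or faster-whisper versions, which are likely relevant for a numerical issue like this one. A small helper along these lines could collect them; `environment_report` is hypothetical, and it only reports a package's `__version__` attribute when that package happens to be installed.

```python
import importlib
import platform


def environment_report():
    """Collect Python, package, and (if available) CUDA/GPU details for a bug report."""
    report = {"python": platform.python_version()}
    for name in ("ctranslate2", "faster_whisper", "torch", "whisper"):
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = "not installed"
            continue
        report[name] = getattr(module, "__version__", "unknown")
    try:
        import torch

        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
        if torch.cuda.is_available():
            report["gpu"] = torch.cuda.get_device_name(0)
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```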
Additional findings
Precision behavior:
Input dependency (bfloat16):