Numerical discrepancy in reduced precision operations causes WER degradation for custom models #1845

JacobAndersson opened this issue Jan 14, 2025 · 1 comment

First posted this issue on faster-whisper, but they suggested that I post it here since it is likely an issue with CTranslate2.

Description

When using custom fine-tuned models, faster-whisper's implementation shows significant WER degradation compared to OpenAI's reference implementation (13.5% vs 8.2% WER). Through investigation, I've traced this to numerical differences that start at the very first Conv1D operation in the encoder and that are also present with the original large-v3 weights. The weights are stored in float16 in both cases, suggesting some numerical or algorithmic issue.

Here are some benchmark results on a custom dataset:

Implementation    Model      Performance (WER)
openai            large-v3   0.124165
faster-whisper    large-v3   0.0992092
openai            custom     0.0819845
faster-whisper    custom     0.135567
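
To make the metric concrete, here is a minimal sketch of how a WER number like the ones above can be computed, using the jiwer package (an assumption on my part; any WER tool gives the same edit-distance-based figure):

from jiwer import wer  # pip install jiwer

# Toy reference / hypothesis pairs, only to illustrate the metric.
references = ["the quick brown fox", "jumps over the lazy dog"]
hypotheses = ["the quick brown fox", "jumps over a lazy dog"]

print(f"WER: {wer(references, hypotheses):.4f}")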

After comparing the logits of the two implementations and narrowing down the root cause, I managed to locate where they diverge: the difference starts as early as the first conv1d operation in the encoder.

Steps to reproduce

Modify the Whisper encoder operator function so that it only applies the conv1d operation, i.e.:

void WhisperEncoder::operator()(const StorageView& features, StorageView& output) {
      PROFILE("WhisperEncoder");

      if (features.rank() != 3)
        throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
                                    + std::to_string(features.rank())
                                    + " dimension(s) instead");

      if (features.dim(1) != input_size() || features.dim(2) > max_input_time())
        throw std::invalid_argument("Invalid input features shape: expected an input with shape ("
                                    + std::to_string(features.dim(0))
                                    + ", "
                                    + std::to_string(input_size())
                                    + ", "
                                    + std::to_string(std::min(features.dim(2), max_input_time()))
                                    + "), but got an input with shape ("
                                    + std::to_string(features.dim(0))
                                    + ", "
                                    + std::to_string(features.dim(1))
                                    + ", "
                                    + std::to_string(features.dim(2))
                                    + ") instead");

      StorageView input(output_type(), features.device());  // left over from the full implementation; unused here

      // Apply only the first convolution and return its raw output,
      // skipping GELU, the second convolution, and the transformer layers.
      _conv1(features, output);
    }

Running the following scripts then shows the diff between the implementations.

faster-whisper:

from faster_whisper import WhisperModel
import numpy as np
import ctranslate2

model = WhisperModel('large-v3', device="cuda", cpu_threads=32, num_workers=1, compute_type="bfloat16")

# Constant mel features so both implementations see identical input.
audio = np.ones((1, 128, 3000), dtype=np.float32)

# With the encoder patched as above, this returns only the conv1 output.
enc = model.encode(audio)
enc_numpy = np.array(enc.to_device(ctranslate2.Device.cpu).to(ctranslate2.DataType.float32))

for i in range(5):
    print(enc_numpy[0, 0, -5 + i])

The PyTorch equivalent is:

import numpy as np
import torch
import torch.nn.functional as F

# Same constant input as in the faster-whisper script, created directly in bfloat16.
audio = torch.ones((1, 128, 3000), dtype=torch.bfloat16, device='cuda:0', requires_grad=False)

# Load the original OpenAI checkpoint and extract the first encoder convolution.
weights = torch.load('./large-v3.pt')
weight = weights['model_state_dict']['encoder.conv1.weight'].to(torch.bfloat16).to('cuda:0')
bias = weights['model_state_dict']['encoder.conv1.bias'].to(torch.bfloat16).to('cuda:0')

out = F.conv1d(audio, weight, bias, padding=1)

enc_numpy_2 = out.detach().cpu().to(torch.float32).numpy()

for i in range(5):
    print(enc_numpy_2[0, 0, -5 + i])

If I run these two scripts, I get the following outputs:

faster-whisper:

0.029418945
0.029418945
0.029418945
0.029418945
-0.003112793

PyTorch:

0.029296875
0.029296875
0.029296875
0.029296875
-0.0029296875
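
To put the size of the gap in perspective, here is a quick check on the printed values against bfloat16 machine epsilon (this only quantifies the discrepancy; it does not explain the WER regression):

import torch

# Last five values printed by each script above.
fw = [0.029418945, 0.029418945, 0.029418945, 0.029418945, -0.003112793]
pt = [0.029296875, 0.029296875, 0.029296875, 0.029296875, -0.0029296875]

print("bfloat16 eps:", torch.finfo(torch.bfloat16).eps)  # 0.0078125

for a, b in zip(fw, pt):
    abs_diff = abs(a - b)
    rel_diff = abs_diff / abs(b)
    print(f"abs diff = {abs_diff:.3e}, rel diff = {rel_diff:.3%}")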

Environment

  • Python 3.10.10
  • CUDA Version: 12.3
  • All running inside a Docker container: nvcr.io/nvidia/pytorch:23.10-py3
  • GPU: 3090

Additional findings

Precision behavior (a sketch of this sweep is included after the lists below):

  • float32: values match, but performance still degrades when I run the full benchmark
  • bfloat16: values differ
  • float16: values differ

Input dependency (bfloat16):

  • Zero inputs: match perfectly
  • Unit inputs (1.0): small difference
  • Larger inputs (2.0): larger differences
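
A minimal sketch of how this sweep can be scripted, assuming the two snippets above are wrapped into helper functions (run_ct2 and run_torch are hypothetical names; each would take the input array plus a compute type and return the conv1 output as a float32 numpy array):

import numpy as np

# Hypothetical helpers wrapping the two snippets above.
from compare_conv1 import run_ct2, run_torch  # assumed local module

for compute_type in ("float32", "float16", "bfloat16"):
    for fill in (0.0, 1.0, 2.0):
        audio = np.full((1, 128, 3000), fill, dtype=np.float32)
        a = run_ct2(audio, compute_type=compute_type)
        b = run_torch(audio, compute_type=compute_type)
        print(f"{compute_type:9s} fill={fill}: max abs diff = {np.abs(a - b).max():.3e}")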
vakkov (Contributor) commented Jan 15, 2025

I am also fighting accuracy problems after conversion and have been running similar experiments, but I get exactly the same numbers for the convolutions (I modified my CTranslate2 Whisper implementation as well and am adding back specific operations). I believe you should explicitly specify the parameters for the convolutions in your torch experiments:


import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

audio = torch.ones((1, 128, 3000), dtype=torch.bfloat16, device='cuda:0', requires_grad=False)

weights = torch.load('/home/gpu/.cache/whisper/large-v3.pt')
weight1 = weights['model_state_dict']['encoder.conv1.weight'].to(torch.bfloat16).to('cuda:0')
bias1   = weights['model_state_dict']['encoder.conv1.bias'].to(torch.bfloat16).to('cuda:0')

weight2 = weights['model_state_dict']['encoder.conv2.weight'].to(torch.bfloat16).to('cuda:0')
bias2 = weights['model_state_dict']['encoder.conv2.bias'].to(torch.bfloat16).to('cuda:0')

# Functional version of the same pipeline, kept for reference:
# out = F.conv1d(audio, weight1, bias1, padding=1)
# out = F.gelu(out)
# out = F.conv1d(out, weight2, bias2, padding=1)

conv1 = nn.Conv1d(
    in_channels=128,
    out_channels=1280,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True
).to(audio.device, dtype=torch.bfloat16)


conv2 = nn.Conv1d(
    in_channels=1280,
    out_channels=1280,
    kernel_size=3,
    stride=2,
    padding=1,
    bias=True
).to(audio.device, dtype=torch.bfloat16)

# Copy the weights and biases into the conv layers
with torch.no_grad():
    conv1.weight.copy_(weight1)
    conv1.bias.copy_(bias1)
    conv2.weight.copy_(weight2)
    conv2.bias.copy_(bias2)

# Forward pass through conv1 -> GELU -> conv2 -> GELU
x = conv1(audio)
x = F.gelu(x)


y = x.detach().cpu().to(torch.float32).numpy()
for i in range(5):
    print(y[0, 0, -5 + i])

print("Output shape after conv1:", x.shape)

x = conv2(x)
x = F.gelu(x)

print("Output shape after conv1+conv2:", x.shape)

enc_numpy_2 = x.detach().cpu().to(torch.float32).numpy()

# enc_numpy_2 = np.transpose(enc_numpy_2, (0, 2, 1))  # to (batch, time, channels), if the layouts differ

print("Torch: ", enc_numpy_2.shape)


for i in range(5):
    print(enc_numpy_2[0, 0, -5 + i])

# Save enc_numpy_2 to a file for offline comparison
np.save('enc_numpy_torch.npy', enc_numpy_2)
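
To close the loop, the saved array can be diffed against whatever the CTranslate2 side produces, assuming that output is dumped the same way (enc_numpy_ct2.npy is a hypothetical file name for that dump):

import numpy as np

torch_out = np.load('enc_numpy_torch.npy')
ct2_out = np.load('enc_numpy_ct2.npy')  # hypothetical dump from the CTranslate2 side

# If the layouts differ (see the commented-out transpose above), flip the
# Torch output from (batch, channels, time) to (batch, time, channels).
if torch_out.shape != ct2_out.shape:
    torch_out = np.transpose(torch_out, (0, 2, 1))

diff = np.abs(torch_out - ct2_out)
print("max abs diff :", diff.max())
print("mean abs diff:", diff.mean())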
