Fix batchnorm in testmode without track stats #576

Merged
merged 1 commit into from
Apr 23, 2024

paulnovo
Contributor

@paulnovo paulnovo commented Apr 21, 2024

In test mode, the CUDA cuDNN implementation of batchnorm did not match the CPU batchnorm in Flux. In Flux, with track_stats=false, the mean and variance of the current batch are used. Here, the mean and variance were initialized to 0 and 1, respectively, and passed to cudnnBatchNormalizationForwardInference, so the input was not normalized over the current batch.

To fix this, we need to calculate the mean and variance over the current batch to match the CPU implementation. Unfortunately, cudnnBatchNormalizationForwardInference requires a trained running mean and variance. However, without tracked stats, batchnorm in train and test mode should be identical, since both normalize over the current batch. As a result, we can use cudnnBatchNormalizationForwardTraining in test mode as well, which works without a running mean and variance.
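To illustrate the reasoning above, here is a minimal sketch in plain Python. It is not the NNlib or cuDNN code from this PR: the function names (`batch_stats`, `normalize`, `batchnorm`, `old_gpu_test_mode`) are hypothetical, it handles a single channel, it omits the learned scale and shift, and it assumes the biased (divide-by-n) variance is used for normalization.

```python
import math

def batch_stats(xs):
    """Mean and biased variance of one channel's values across the batch."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return mean, var

def normalize(xs, mean, var, eps=1e-5):
    """Apply (x - mean) / sqrt(var + eps) element-wise."""
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

def batchnorm(xs, training, track_stats,
              running_mean=0.0, running_var=1.0, eps=1e-5):
    """Hypothetical reference behaviour: running statistics are only
    used in test mode when they are actually tracked."""
    if track_stats and not training:
        mean, var = running_mean, running_var
    else:
        # Without tracked stats, train and test both use the batch.
        mean, var = batch_stats(xs)
    return normalize(xs, mean, var, eps)

def old_gpu_test_mode(xs, eps=1e-5):
    """The mismatch described above: untrained placeholders
    (mean 0, variance 1) are used as if they were running stats."""
    return normalize(xs, 0.0, 1.0, eps)

batch = [1.0, 2.0, 3.0, 4.0]
train_out = batchnorm(batch, training=True, track_stats=False)
test_out = batchnorm(batch, training=False, track_stats=False)
old_out = old_gpu_test_mode(batch)

print(train_out == test_out)  # train and test agree without tracked stats
print(old_out == test_out)    # the placeholder path gives a different result
```

With placeholder statistics the input is only divided by sqrt(1 + eps), so it passes through almost unchanged, whereas the batch-statistics path produces zero-mean output. This is why routing test mode through the training-style computation gives the same result as the CPU path when no stats are tracked.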

This is needed to help address FluxML/Flux.jl#1606 along with Flux.jl PR 2427.

PR Checklist

  • Tests are added
  • Documentation, if applicable

@CarloLucibello CarloLucibello merged commit e8e7572 into FluxML:master Apr 23, 2024
11 of 13 checks passed