-
-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve type stability of LayerNorm and Dropout #2005
base: master
Are you sure you want to change the base?
Conversation
25f0a1b
to
9259e4a
Compare
TTFG timings using the following snippet: Test codeusing Metalhead, Flux, Zygote
using Metalhead: ChannelLayerNorm
model = ConvNeXt(:tiny; inchannels=1, nclasses=1).layers
# ChannelLayerNorm isn't type stable yet (for the same reason as LayerNorm wasn't),
# So remove it for this demo
model = fmap(Returns(identity), model; exclude=Base.Fix2(isa, ChannelLayerNorm))
# display(model); println()
loss(m, x) = sum(m(x))
inputs = randn(Float32, 32, 32, 1, 1)
# @time loss(model, inputs)
# @time loss(model, inputs)
loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)
@time loss_grad(model, inputs)
# @time loss_grad(model, inputs)
Replacing the
|
For kicks, here is Diffractor with JuliaDiff/ChainRules.jl#644: julia> @time loss_grad(model, inputs)
30.442982 seconds (92.61 M allocations: 4.148 GiB, 3.18% gc time, 89.07% compilation time) # tuple chain
23.051121 seconds (88.06 M allocations: 3.920 GiB, 3.81% gc time, 85.11% compilation time) # vector chain, requires https://github.com/JuliaDiff/Diffractor.jl/pull/82 Re-enabling Edit: added times for vector chains using a patched Diffractor. |
Does Diffractor already work with most Flux models (or at least those with built-in layers)? I was under the impression that it wasn't there yet 😅 |
Not OOTB, which is why that ChainRules PR is required. |
@ToucheSir Could you try running the layer norm gradient with gpu? I have try that manual broadcast fusion before but |
You're right, it allocates one more time for over 2x the memory overhead. I also found this out the hard way recently while trying to fuse the RNN cell kernels for #2023, but forgot about the change here. |
9259e4a
to
29ef2ff
Compare
Codecov Report
@@ Coverage Diff @@
## master #2005 +/- ##
==========================================
+ Coverage 87.10% 87.37% +0.27%
==========================================
Files 20 20
Lines 1528 1553 +25
==========================================
+ Hits 1331 1357 +26
+ Misses 197 196 -1
Continue to review full report at Codecov.
|
Any updates on this (like benchmarks after unfusing)? |
These two layers made use of explicit or implicit control flow (e.g. default keyword argument values) which Zygote does not like. This PR is essentially a set of small hacks to work around that.
Any ideas on how to avoid
return_type
in_dropout
would be much appreciated, but for now it seems to work.TODO benchmarks.
PR Checklist