diff --git a/docs/make.jl b/docs/make.jl
index 6c7b483caa..9fc32e7632 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -40,6 +40,7 @@ makedocs(
         "OneHotArrays.jl" => "reference/data/onehot.md",
         "Low-level Operations -- NNlib.jl" => "reference/models/nnlib.md",
         "Nested Structures -- Functors.jl" => "reference/models/functors.md",
+        "Advanced" => "reference/misc-model-tweaking.md"
     ],
     "Tutorials" => [
         # These walk you through various tasks. It's fine if they overlap quite a lot.
diff --git a/docs/src/reference/misc-model-tweaking.md b/docs/src/reference/misc-model-tweaking.md
new file mode 100644
index 0000000000..ed1bd72793
--- /dev/null
+++ b/docs/src/reference/misc-model-tweaking.md
@@ -0,0 +1,135 @@
+# Choosing differentiable/GPU parts of the model
+
+!!! note
+    This page covers a few loosely connected topics about customising your
+    models even further. You should already be familiar with
+    [`Flux.@layer`](@ref), [`Flux.@functor`](@ref), [`freeze!`](@ref Flux.freeze!)
+    and the other basics of Flux.
+
+Flux provides several ways of excluding parts of a model from training: freezing
+parameters on the fly, marking fields as non-trainable in the model definition,
+and leaving fields out of the functor entirely ([`Functors.@functor`](@ref)), so
+that they are neither trained nor moved to the GPU. The following subsections
+should make it clear which approach suits your needs best.
+
+## On-the-fly freezing per model instance
+
+Perhaps you'd like to freeze some of the weights of the model (even in the middle
+of training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!)
+and `thaw!`.
+
+```julia
+m = Chain(
+      Dense(784 => 64, relu),   # freeze this one
+      Dense(64 => 32, relu),
+      Dense(32 => 10)
+    )
+opt_state = Flux.setup(Momentum(), m);
+
+# Freeze some layers right away
+Flux.freeze!(opt_state.layers[1])
+
+for data in train_set
+    input, label = data
+
+    # Some parameters can also be frozen in the middle of training:
+    Flux.freeze!(opt_state.layers[2])
+
+    grads = Flux.gradient(m) do m
+        result = m(input)
+        loss(result, label)   # `loss` is your loss function, e.g. Flux.crossentropy
+    end
+    Flux.update!(opt_state, m, grads[1])
+
+    # Optionally unfreeze the parameters later
+    Flux.thaw!(opt_state.layers[1])
+end
+```
+
+## Static freezing per model definition
+
+Sometimes parts of a model ([`Flux.@layer`](@ref)) need not be trained at all,
+but their parameters still have to reside on the GPU, because they are still used
+in the forward and/or backward pass. This is somewhat similar to PyTorch's
+[`register_buffer`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer).
+
+```julia
+struct MaskedLayer chain; mask; end
+
+Flux.@layer MaskedLayer trainable=(chain,)
+# the `mask` field will not be updated in the training loop
+
+function (m::MaskedLayer)(x)
+    # the `mask` field will still be moved to the GPU for efficient operations:
+    return m.chain(x) + x + m.mask
+end
+
+model = MaskedLayer(...)  # this model will not have its `mask` field trained
+```
+
+Note that this approach excludes fields from training permanently, as part of the
+model definition, rather than toggling them on and off during training.
+
+## Excluding from model definition
+
+Sometimes some fields aren't just "not trainable": they shouldn't even be
+transferred to the GPU, or be part of the functor at all. Scalar fields behave
+like this by default, so things such as learning-rate multipliers are neither
+trained nor moved to the GPU.
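+
+For instance, here is a minimal sketch (the `ScaledDense` layer and its fields
+are made-up names, not part of Flux) of a plain number field that gets no
+optimiser state and is left untouched by `gpu`:
+
+```julia
+struct ScaledDense
+    layer    # a trainable sub-layer, e.g. a Dense
+    scale    # a plain number: no optimiser state, not converted by `gpu`
+end
+
+Flux.@layer ScaledDense
+
+(m::ScaledDense)(x) = m.scale .* m.layer(x)
+
+m = ScaledDense(Dense(3 => 2), 0.5f0)
+# `Flux.setup(Adam(), m)` creates no optimiser state for `scale`,
+# and `gpu(m)` converts only the arrays inside `layer`.
+```
+
+For fields that hold arrays or other containers, opt out explicitly by listing
+only the fields that should be functor children: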
+
+```julia
+using Statistics: mean
+
+struct CustomLayer chain; activation_results; end
+
+Flux.@functor CustomLayer (chain,)  # explicitly leaving out `activation_results`
+
+function (m::CustomLayer)(x)
+    result = m.chain(x) + x
+
+    # `activation_results` is not part of the functor, so it is never moved to
+    # the GPU and we can freely `push!` to it:
+    push!(m.activation_results, mean(result))
+    return result
+end
+```
+
+See more about this in [`Flux.@functor`](@ref).
+
+## Freezing Layer Parameters (deprecated)
+
+When you do not want to include all of the model's parameters (e.g. for transfer
+learning), simply omit those layers from the call to `params`.
+
+!!! compat "Flux ≤ 0.14"
+    The mechanism described here is for Flux's old "implicit" training style.
+    When upgrading to Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!)
+    and `thaw!`.
+
+Consider a simple multi-layer perceptron model where we want to avoid optimising
+the first two `Dense` layers. We can obtain this using the slicing features
+`Chain` provides:
+
+```julia
+m = Chain(
+      Dense(784 => 64, relu),
+      Dense(64 => 32, relu),
+      Dense(32 => 10)
+    );
+
+ps = Flux.params(m[3:end])
+```
+
+The `Zygote.Params` object `ps` now holds a reference to only the parameters of
+the layers passed to it.
+
+During training, gradients will only be computed for (and applied to) the last
+`Dense` layer, so only that layer will have its parameters changed.
+
+`Flux.params` also takes multiple inputs to make it easy to collect parameters
+from heterogeneous models with a single call. For example, to omit optimising
+only the second `Dense` layer in the previous example:
+
+```julia
+Flux.params(m[1], m[3:end])
+```
+
+Sometimes finer-grained control is needed. We can freeze a specific parameter of
+a specific layer which has already entered a `Params` object `ps`, by simply
+deleting it from `ps`:
+
+```julia
+ps = Flux.params(m)
+delete!(ps, m[2].bias)
+```
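+
+With the newer explicit style, the closest equivalent is to `freeze!` the
+corresponding part of the optimiser state returned by `Flux.setup`. A rough
+sketch, assuming the same three-layer `m` as above (the state tree mirrors the
+model's structure, so individual parameters can be addressed by field name):
+
+```julia
+opt_state = Flux.setup(Momentum(), m)
+
+# freeze everything but the last layer, like `params(m[3:end])` above
+Flux.freeze!(opt_state.layers[1])
+Flux.freeze!(opt_state.layers[2])
+
+# or freeze a single parameter, like `delete!(ps, m[2].bias)` above
+Flux.freeze!(opt_state.layers[2].bias)
+```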