From 0a679c1f4ac48d97e5ac1641f7b4ed182486ea78 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Tue, 26 Nov 2024 08:04:20 +0100
Subject: [PATCH] noexpand

---
 src/layers/attention.jl |  2 +-
 src/layers/macro.jl     | 34 +++++++++++++++++++++++-----------
 src/layers/normalise.jl |  2 +-
 src/layers/show.jl      | 17 +++++++++++++++++
 4 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/src/layers/attention.jl b/src/layers/attention.jl
index d4a33283d9..50c023d7ca 100644
--- a/src/layers/attention.jl
+++ b/src/layers/attention.jl
@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
  out_proj::P2
end

-@layer MultiHeadAttention
+@layer :noexpand MultiHeadAttention

function MultiHeadAttention(dims;
  nheads::Int = 8,
diff --git a/src/layers/macro.jl b/src/layers/macro.jl
index eee3072eb0..b76f41570a 100644
--- a/src/layers/macro.jl
+++ b/src/layers/macro.jl
@@ -1,20 +1,24 @@
"""
-    @layer MyModel
-    @layer MyModel trainable=(β,γ)
-
+    @layer [showtype] MyModel [trainable=(field1,...)]
+
This macro adds convenience functionality to a custom type to serve as a neural network layer,
as a module, or as an entire model.

The optional keyword `trainable` allows you to specify which fields of your model can be trained,
instead of assuming all `fieldnames(MyModel)` to be trainable.
Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
This can also be done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.

The macro also handles overloads of the 3-arg `show(::IO, ::MIME"text/plain", ::MyModel)` for pretty printing.
The optional argument `showtype` can take any of the following values:

- `:expand` (default): This will expand the representation of container types like `Chain`,
  while maintaining a compact representation of types like `Dense` containing only arrays.
- `:noexpand`: Use this when your type contains other layers but you want to keep the representation simple.
- `:ignore`: Use this to opt out of the pretty printing entirely.

You probably still want to define the 2-arg `show(::IO, ::MyModel)`; the macro does not touch this.

Note that re-running the macro with different options may not remove all methods; you will need to restart Julia.
@@ -33,6 +37,12 @@ Trio(
  Dense(1 => 1; bias=false),  # 1 parameters
  Dropout(0.4),
) # Total: 3 arrays, 4 parameters, 240 bytes.
+
+# Freeze `c`, equivalent to `Optimisers.trainable(tri::Trio) = (; tri.a, tri.b)`
+julia> Flux.@layer Trio trainable=(a,b)
+
+# Now the optimizer's state won't contain `c`
+julia> opt_state = Flux.setup(Adam(), tri);
```
"""
@@ -45,8 +55,10 @@ function _layer_macro(exs...)

  # These functions are defined in show.jl, and each return an expression overloading Base.show
  type, rest... = if exs[1] == QuoteNode(:expand)
-    @warn "The `:expand` option is deprecated, and will be removed in a future release. Use `@layer` without options instead." maxlog=1
-    push!(out.args, _macro_big_show(esc(exs[1])))
+    push!(out.args, _macro_big_show(esc(exs[2])))
+    exs[2:end]
+  elseif exs[1] == QuoteNode(:noexpand)
+    push!(out.args, _macro_layer_show(esc(exs[2])))
    exs[2:end]
  elseif exs[1] == QuoteNode(:ignore)
    exs[2:end]
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 9d294e3e6e..dded9ab306 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -198,7 +198,7 @@ end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

-@layer LayerNorm
+@layer :noexpand LayerNorm

function (a::LayerNorm)(x::AbstractArray)
  ChainRulesCore.@ignore_derivatives if a.diag isa Scale
diff --git a/src/layers/show.jl b/src/layers/show.jl
index 06cd93f314..8b5aa716b5 100644
--- a/src/layers/show.jl
+++ b/src/layers/show.jl
@@ -83,6 +83,23 @@ function _flat_children(x)
  gamma = ((beta...)...,)
end

+# This is called by @layer :noexpand, on layers which should be treated like Dense, and returns an expression:
+function _macro_layer_show(ex)
+  quote
+    # Entry point:
+    function Base.show(io::IO, m::MIME"text/plain", x::$ex)
+      if !get(io, :compact, false)
+        _layer_show(io, x)
+      else
+        show(io, x)
+      end
+    end
+
+    # Exit from _big_show recursion:
+    Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
+  end
+end
+
function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
  _str = isnothing(name) ? "" : "$name = "
  str = _str * _layer_string(io, layer)
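
Below is a minimal usage sketch of the new `:noexpand` option, assuming this patch is applied. The `SkipBlock` type and its dimensions are hypothetical, invented only for illustration; the patch itself only switches `MultiHeadAttention` and `LayerNorm` to `:noexpand`.

```julia
using Flux

# A hypothetical container layer: it holds two Dense sublayers,
# but we want it printed on one line, much like Dense itself.
struct SkipBlock{L1, L2}
    up::L1
    down::L2
end

# Convenience constructor: project up to a hidden width, then back down.
SkipBlock(dim::Int, hidden::Int) = SkipBlock(Dense(dim => hidden, relu), Dense(hidden => dim))

# Forward pass with a residual connection.
(m::SkipBlock)(x) = x .+ m.down(m.up(x))

# With this patch, :noexpand keeps the compact, Dense-like printing
# even though SkipBlock contains other layers:
Flux.@layer :noexpand SkipBlock

model = Chain(SkipBlock(4, 16), Dense(4 => 2))
```

With `:noexpand`, the `MIME"text/plain"` display of `model` should show each `SkipBlock` on a single line with its parameter count, much as `Dense` is shown, whereas the default (`:expand`, i.e. plain `Flux.@layer SkipBlock`) would recurse into its `up` and `down` fields. As the updated docstring notes, defining a 2-arg `show(::IO, ::SkipBlock)` would further tidy that single-line form.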