Commit 0a679c1

noexpand

CarloLucibello committed Nov 26, 2024
1 parent a0b083a, commit 0a679c1
Showing 4 changed files with 42 additions and 13 deletions.
src/layers/attention.jl (2 changes: 1 addition & 1 deletion)

@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
   out_proj::P2
 end

-@layer MultiHeadAttention
+@layer :noexpand MultiHeadAttention

 function MultiHeadAttention(dims;
   nheads::Int = 8,
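The effect of this one-line change is easiest to see on a small wrapper type of your own. Below is a minimal sketch, assuming a Flux version that includes this commit; the `Residual` type is hypothetical, not part of Flux:

using Flux

# A wrapper holding another layer, which should nonetheless print on one line like `Dense`:
struct Residual{L}
    inner::L
end

(r::Residual)(x) = x .+ r.inner(x)

Flux.@layer :noexpand Residual

# Under the default :expand, printing this Chain would recurse into `inner`;
# with :noexpand, the Residual is summarized on a single line with its parameter count.
model = Chain(Residual(Dense(4 => 4, relu)), Dense(4 => 2))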
src/layers/macro.jl (34 changes: 23 additions & 11 deletions)

@@ -1,20 +1,24 @@

"""
@layer MyModel
@layer MyModel trainable=(β,γ)
@layer [showtype] MyModel [trainable=(field1,...)]
This macro adds convenience functionality to a custom type to serve
as a neural network layer, as a module, or as an entire model.
The keyword `trainable` allows you to specify which fiels of you model can be trained,
The optional keyword `trainable` allows you to specify which fields of your model can be trained,
instead of assuming all `fieldnames(MyModel)` to trainable.
Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
This can be also done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.
This can be also be done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.
The macro also handles overloads of the 3-arg `show(::IO, ::MIME"text/plain", ::MyModel)` for pretty printing.
The optional argument `showtype` can take any of the following values:
The macro also handles overloads of `show` for pretty printing.
It adds methods to `show(::IO, ::MIME"text/plain", ::MyModel)` to treat your layer much like `Dense` or `Chain`.
To opt out of this, use `@layer :ignore MyModel`.
In case, you probably still want to define 2-arg `show(::IO, ::MyModel)`, the macro does not touch this.
- `:expand` (default): This will expand the representation of container types like `Chain`,
while maintaining a compat representation of types like `Dense` containing only arrays.
- `:noexpand`: This is to be used in case your type contains other layers but you want to keep the representation simple.
- `:ignore`: To opt out of the pretty printing.
You probably still want to define 2-arg `show(::IO, ::MyModel)`, the macro does not touch this.
Note that re-running the macro with different options may not remove all methods, you will need to restart.
@@ -33,6 +37,12 @@ Trio(
   Dense(1 => 1; bias=false),  # 1 parameters
   Dropout(0.4),
 )                   # Total: 3 arrays, 4 parameters, 240 bytes.
+
+# Freeze `c`, equivalent to `Optimisers.trainable(tri::Trio) = (; tri.a, tri.b)`
+julia> Flux.@layer Trio trainable=(a,b)
+
+# Now the optimizer's state won't contain `c`
+julia> opt_state = Flux.setup(Adam(), tri);
 ```
 """
@@ -45,8 +55,10 @@ function _layer_macro(exs...)

   # These functions are defined in show.jl, and each return an expression overloading Base.show
   type, rest... = if exs[1] == QuoteNode(:expand)
-    push!(out.args, _macro_big_show(esc(exs[1])))
+    @warn "The `:expand` option is deprecated, and will be removed in a future release. Use `@layer` without options instead." maxlog=1
+    push!(out.args, _macro_big_show(esc(exs[2])))
     exs[2:end]
+  elseif exs[1] == QuoteNode(:noexpand)
+    push!(out.args, _macro_layer_show(esc(exs[2])))
+    exs[2:end]
   elseif exs[1] == QuoteNode(:ignore)
     exs[2:end]
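One detail worth noting in the branch above: a literal symbol passed to a macro arrives as a `QuoteNode`, which is why the comparisons are against `QuoteNode(:noexpand)` rather than `:noexpand`. A REPL sketch of the parsing, independent of Flux (the name `MyModel` is a placeholder):

julia> ex = :(@layer :noexpand MyModel);

julia> ex.args[3:end]   # the macro's arguments, after the macro name and a LineNumberNode
2-element Vector{Any}:
 :(:noexpand)
 :MyModel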
src/layers/normalise.jl (2 changes: 1 addition & 1 deletion)

@@ -198,7 +198,7 @@ end
 LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
 LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

-@layer LayerNorm
+@layer :noexpand LayerNorm

 function (a::LayerNorm)(x::AbstractArray)
   ChainRulesCore.@ignore_derivatives if a.diag isa Scale
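Since `:expand` is now the default, `LayerNorm`, which holds a `Scale` layer in its `diag` field (visible in the forward pass above), would otherwise start printing as a container; the explicit `:noexpand` keeps its familiar one-line form. An illustrative session, with the printed output approximate:

julia> using Flux

julia> LayerNorm(4)    # the Scale(4) inside contributes 4 + 4 parameters
LayerNorm(4)    # 8 parameters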
src/layers/show.jl (17 changes: 17 additions & 0 deletions)

@@ -83,6 +83,23 @@ function _flat_children(x)
   gamma = ((beta...)...,)
 end

+# This is called by `@layer :noexpand`, on layers which should be treated like `Dense`, and returns an expression:
+function _macro_layer_show(ex)
+  quote
+    # Entry point:
+    function Base.show(io::IO, m::MIME"text/plain", x::$ex)
+      if !get(io, :compact, false)
+        _layer_show(io, x)
+      else
+        show(io, x)
+      end
+    end
+
+    # Exit from _big_show recursion:
+    Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
+  end
+end
+
 function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
   _str = isnothing(name) ? "" : "$name = "
   str = _str * _layer_string(io, layer)
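Spelled out, `@layer :noexpand MyWrapper` splices this quoted block with `$ex` replaced by the type, so it is roughly equivalent to hand-writing the following (a sketch for a hypothetical `MyWrapper`; `_layer_show` and `_big_show` are Flux internals):

# 3-arg show is the REPL entry point: print the one-line summary unless :compact is set.
function Base.show(io::IO, m::MIME"text/plain", x::MyWrapper)
    if !get(io, :compact, false)
        Flux._layer_show(io, x)
    else
        show(io, x)   # fall back to 2-arg show in compact contexts
    end
end

# When a containing Chain's _big_show recursion reaches a MyWrapper, stop expanding:
Flux._big_show(io::IO, obj::MyWrapper, indent::Int=0, name=nothing) =
    Flux._layer_show(io, obj, indent, name)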
