
Overhaul of ResNet API #174

Merged
merged 67 commits into FluxML:master from resnet-plus on Aug 2, 2022

Conversation

theabhirath
Member

@theabhirath theabhirath commented Jun 21, 2022

This PR completely rewrites the current ResNet API to make it more powerful and more extensible, and to reduce code duplication.

Why this PR?

While making ResNet more fully-featured, this PR will also:

  1. Add support for DropBlock to ResNet, a type of regularisation used in place of Dropout in some networks
  2. Add support for an optional deeper stem up front
  3. Add support for attention layers
  4. Add support for multiple pooling options in the classifier head
  5. Allow for more ResNet block variants (such as those from the Bag of Tricks paper)

Things to do

  • Abstract out pooling and classifier heads to provide more options
  • Add attention layer
  • Specific constructors for the ResNet paper models
  • Documentation, Documentation and more documentation
    • DropBlock and its behaviour
    • DropPath behaviour in detail (permissible values, calculations)
    • Higher level ResNet interface vs lower level resnet interface
  • Rewrite ResNeXt to use the new resnet API
  • Write more tests
  • Benchmark to make sure there are no regressions
    • Forward pass
    • Backward pass
    • TTFG (time to first gradient)
  • Figure out a way to port pre-trained weights

Other PRs to land before this one

Miscellaneous fixes

  • Adds a type argument to densenet for nblocks to avoid hitting integer edge cases

@theabhirath
Member Author

theabhirath commented Jun 23, 2022

Some perks of the new API:

0.7.2:

julia> model = ResNet(50);

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 6.698 s (87.06% GC) to evaluate,
 with a memory estimate of 2.46 GiB, over 47810 allocations.

julia> model = ResNet(18);

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 2 samples with 1 evaluation.
 Range (min … max):  2.576 s …  2.580 s  ┊ GC (min … max): 87.60% … 87.65%
 Time  (median):     2.578 s             ┊ GC (median):    87.63%
 Time  (mean ± σ):   2.578 s ± 2.770 ms  ┊ GC (mean ± σ):  87.63% ±  0.03%

  █                                                      █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  2.58 s        Histogram: frequency by time        2.58 s <

 Memory estimate: 1.01 GiB, allocs estimate: 19594.

This PR:

julia> model = ResNet(50);

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 5.644 s (85.62% GC) to evaluate,
 with a memory estimate of 2.50 GiB, over 45095 allocations.

julia> model = ResNet(18);

julia> @benchmark Zygote.gradient(p -> sum($model(p)), $x)
BenchmarkTools.Trial: 13 samples with 1 evaluation. 
 Range (min … max):  338.901 ms … 612.421 ms  ┊ GC (min … max):  4.01% … 46.50%
 Time  (median):     345.959 ms               ┊ GC (median):     5.21%
 Time  (mean ± σ):   416.913 ms ±  90.533 ms  ┊ GC (mean ± σ):  21.52% ± 16.19%

  █▄                            ▁
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▆▆▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  339 ms           Histogram: frequency by time          612 ms <

 Memory estimate: 1.03 GiB, allocs estimate: 17275.

Julia version info:

julia> versioninfo()
Julia Version 1.9.0-DEV.840
Commit 68d62ab3d3 (2022-06-22 21:39 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.5.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.5 (ORCJIT, apple-m1)
  Threads: 4 on 4 virtual cores
Environment:
  JULIA_NUM_THREADS = 4

@theabhirath
Member Author

How is it that Zygote seems to be getting worse with each passing Julia version, though? I could've sworn it wasn't this bad a couple of weeks ago, and today it seems to be struggling to even calculate a ResNet-50 gradient?

@darsnack
Member

I'm confused how the new API contributes to better gradient times (at least for ResNet where there are no new layers added, right)?

A couple of high-level comments as you work on this:

  • Let's try to make the API more like "pass in what you want" than keywords. What I mean by this is that keywords corresponding to the stem can be eliminated by just having a single stem argument that the user passes in. Then, similar to block, we provide the useful defaults. So someone would pass in resnet_stem(64, :deep) (and we get to allow stuff beyond the defaults...with a "use at your own risk" warning). See the sketch after this list for the rough shape.
  • Try to consolidate as many keyword arguments into the blocks themselves. So if cardinality is just directly passed to the block, then we don't need to explicitly declare it in resnet. In terms of documentation, the default blocks' docstrings should detail these arguments and resnet defers to those.
  • You can submit a PR to the HuggingFace repo when you are ready to port the weights.
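A minimal sketch of the stem idea above, assuming a hypothetical resnet_stem builder (the names, the :deep option, and the layer choices are illustrative rather than existing Metalhead API; norm layers are omitted for brevity):

using Flux

# Hypothetical stem builder: returns the stem as an ordinary Flux model, so
# `resnet` can simply accept whatever stem it is handed.
function resnet_stem(inchannels::Integer = 3; mode::Symbol = :default)
    if mode === :deep
        # a deeper, 3×3-based stem in the spirit of the "Bag of Tricks" paper
        return Chain(Conv((3, 3), inchannels => 32, relu; stride = 2, pad = 1),
                     Conv((3, 3), 32 => 32, relu; pad = 1),
                     Conv((3, 3), 32 => 64, relu; pad = 1),
                     MaxPool((3, 3); stride = 2, pad = 1))
    else
        # the classic 7×7 stem from the original ResNet paper
        return Chain(Conv((7, 7), inchannels => 64, relu; stride = 2, pad = 3),
                     MaxPool((3, 3); stride = 2, pad = 1))
    end
end

# the user passes in the stem they want, so `resnet` needs no stem keywords:
# model = resnet(basicblock, [2, 2, 2, 2]; stem = resnet_stem(mode = :deep))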

@theabhirath
Member Author

theabhirath commented Jun 23, 2022

I'm confused how the new API contributes to better gradient times (at least for ResNet where there are no new layers added, right)?

Well, to be completely honest, I'm not sure, but I have some theories that mostly revolve around how the Chains are nested in the two models. But I can't say for sure since I haven't really tried any of that out.

  • Let's try to make the API more like "pass in what you want" than keywords. What I mean by this is that keywords corresponding to the stem can be eliminated by just having a single stem argument that the user passes in. Then, similar to block, we provide the useful defaults.

This makes sense. We could make these NamedTuples or Dicts, maybe? It's a lot easier to keep track of names instead of the order of arguments to be passed in, and it's annoying to have to pass in five irrelevant arguments just to get to the last relevant one.

  • Try to consolidate as many keyword arguments into the blocks themselves. So if cardinality is just directly passed to the block, then we don't need to explicitly declare it in resnet. In terms of documentation, the default blocks' docstrings should detail these arguments and resnet defers to those.

👍🏽

  • You can submit a PR to the HuggingFace repo when you are ready to port the weights.

This...might take time. The major issue is in terms of getting the model structures to overlap. DropBlock and DropPath add some functionality, but at the cost of having some extra identity layers in the model for the default cases, which is why I was looking at FluxML/Flux.jl#2004. I'll try and see what I can do, though.

@darsnack
Member

darsnack commented Jun 23, 2022

We could make these NamedTuples or Dicts, maybe?

I was thinking even more declarative. Just have a function called resnet_stem or whatever you think is appropriate. So like resnet(stem = resnet_stem(mode = :tiered), ...) where resnet_stem actually returns the stem just like basicblock actually returns the block. This would allow for the same functionality as the keywords or named tuple, while also allowing the stem to be any model (so as flexible as possible).

I find that these kinds of declarative interfaces are more flexible and easier to keep track of mentally. But they usually take more typing. Possibly we can merge your idea with this and allow a named tuple or Dict too. Then that dispatch should pass the pairs of the named tuple into resnet_stem by default. You would have some intermediate _make_stem(::NamedTuple) / _make_stem(x) that is called by the builder so that you can dispatch on the type.
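That intermediate dispatch could look roughly like this (building on the hypothetical resnet_stem sketch earlier in this thread; _make_stem is the name suggested above, everything else is illustrative):

# Illustrative only: dispatch on what the user passed for the stem.
# A NamedTuple is treated as keyword arguments for the default stem builder;
# anything else is assumed to already be a usable stem model.
_make_stem(stem_args::NamedTuple) = resnet_stem(; stem_args...)
_make_stem(stem) = stem

# both forms would then route through the same builder:
# resnet(basicblock, [2, 2, 2, 2]; stem = (; mode = :deep))            # NamedTuple path
# resnet(basicblock, [2, 2, 2, 2]; stem = resnet_stem(mode = :deep))   # explicit model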

It's a lot easier to keep track of names instead of the order of arguments to be passed in, and it's annoying to have to pass in five irrelevant arguments just to get to the last relevant one.

Yeah, I think the resnet_stem function does not need to be positional. It can accept keywords for this purpose.

The major issue is in terms of getting the model structures to overlap. DropBlock and DropPath add some functionality but at the cost of having some extra identitys in the model for the default cases, which is why I was looking at FluxML/Flux.jl#2004.

No hurry on this. Also, the script linked in the HuggingFace model cards doesn't depend on structure. It turns the Flux model into a state dict-like dictionary then just iterates the keys together with the PyTorch state dict. It might just work for your model since the DropBlock stuff does not contain parameters that would affect the state dict.

@theabhirath
Member Author

theabhirath commented Jun 25, 2022

  • Try to consolidate as many keyword arguments into the blocks themselves. So if cardinality is just directly passed to the block, then we don't need to explicitly declare it in resnet. In terms of documentation, the default blocks' docstrings should detail these arguments and resnet defers to those.

I was trying to come up with a more declarative API, but one of the problems that we might face is in terms of documentation. Since these blocks have a lot of arguments, directing end-users to refer to the documentation for these blocks might cause some confusion. I'm a little uncertain if that's desirable. Maybe we keep the declarative API but document the kwargs one level higher anyways?

This also causes quite a bit of argument hiding (i.e. builder functions aren't explicitly accepting the arguments to be passed to the lower level ones, but instead a NamedTuple or Dict), which I'm not sure is the right way to go. The API becomes a little less clear for both end users and package developers.

@darsnack
Member

builder functions aren't explicitly accepting the arguments to be passed to the lower level ones but instead a NamedTuple or Dict

Don't they just accept something like block_args...?

Maybe we keep the declarative API but document the kwargs one level higher anyways?

That's okay

@theabhirath
Member Author

Don't they just accept something like block_args...

I could, but this has the same problem - the function doesn't clearly "see" the kwargs being passed in, which means we're essentially banking on users of this function to play nice. Which is fine, but it feels kinda wrong to have too many functions where the inputs aren't regulated

@darsnack
Member

This is a natural conflict between designing something to be flexible vs. safe. In general, Julia code tries to be more permissive, especially at the lower level API. This is what makes it possible to smash together two totally separate packages and get a useful result without too much hacking. This approach definitely requires more care, and I find the best way to work through this is to just try and be permissive until you hit a roadblock. Usually that experience is most informative about the design space. Let me try and walk through some of that process below.

it feels kinda wrong to have too many functions where the inputs aren't regulated

They are regulated, just not by resnet. For example, take the cardinality keyword and suppose it does not have a default. Then the user is expected to specify it.

  1. If resnet explicitly has this keyword, then the user who fails to specify it will get a MethodError for the resnet method with the cardinality keyword highlighted as missing.
  2. If resnet does not have this keyword, then the user who fails to specify it will get a MethodError for the basicblock (or whatever block_fn) method with the cardinality keyword highlighted as missing.

Either way, the user gets the same error. A similar outcome will happen for invalid keywords (with a slightly more informative error too) or for positional arguments. I would argue that getting the error for (2) is more informative because it signals that it is specifically the block_fn and the keyword that is incompatible. For (1), if we choose to restrict block_fn, then we can validate the arguments and provide even more informative errors, but this means that the entire API design is less flexible. If arguments are meant to be in specified ranges, etc. then that kind of assertion should happen inside block_fn where the stack trace itself is as informative as possible about the location of the error.
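A self-contained toy illustrating where the error surfaces in case (2). These are stand-ins, not the Metalhead functions; also note that in recent Julia a missing keyword without a default actually raises an UndefKeywordError rather than a MethodError (an unsupported keyword raises a MethodError), but the point about where the error is raised is the same:

# stand-in block with a `cardinality` keyword that has no default
toy_block(inplanes, outplanes; stride, cardinality) = (inplanes, outplanes, stride, cardinality)

# stand-in builder: never mentions `cardinality`, just forwards keywords to the block
function toy_resnet(block_fn, layers; block_kwargs...)
    return [block_fn(64, 64; stride = 1, block_kwargs...) for _ in 1:sum(layers)]
end

toy_resnet(toy_block, [2, 2])                    # UndefKeywordError for `cardinality`, raised at the block call
toy_resnet(toy_block, [2, 2]; cardinality = 32)  # works; the keyword was forwarded untouched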

Maybe there is a specific kind of error that you are expecting that isn't covered well here? We should discuss that case in more detail then. Also, remember that this is a fairly low-level portion of the API. There is an expectation that the user can read Julia errors here (i.e. not the same level as ResNet where we want to be very beginner friendly). I recommend reading oxinabox's post on Julia anti-patterns and specifically the section on over-specifying argument types for an intro to this "letting the error get thrown eventually" philosophy (relatedly, this post is also good here). Where we want to intercept errors early are cases where we expect the average user to be a beginner, or where the resulting default error is misleading or cryptic.


Documenting the interface is a related but different concern. The interface should appear intuitive by itself.

I would argue that having many specified but restricted keywords is not intuitive. It requires reading the docstring to understand the behavior and how each one is used / when each is ignored. conv_bn is the perfect example of a poor keyword interface (which we've let go because it is very internal).

On the other hand, saying "arguments passed to block_fn" is unspecified but it clearly follows from your understanding of block_fn. If you are customizing the block_fn, then you must understand these keywords and their usage in block_fn no matter what the design is. But if you don't care about block_fn and want the default, then there is nothing for you to read and the interface naturally lets you ignore the keywords. Similarly, if you customize block_fn, then you can feel assured that the extraneous keywords for another block are irrelevant (and you are not forced to implement block_fn in a way that adheres to an interface that doesn't apply to you).

Here's an attempt at the docstring. Let me know what you think (and feel free to push back!). Of course, this would also require similar changes to _make_blocks to be more declarative too. This means basically factoring out code that _make_blocks does automatically into separate functions that can be passed in (e.g. the downsampling function doesn't need to be instantiated in _make_blocks and could be constructed and passed in as a single downsampler argument).

"""
    resnet(block, layers, stem = somedefault(); nclasses = 1000, inchannels = 3, output_stride = 32,
           reduce_first = 1, activation = relu,
           norm_layer = BatchNorm, drop_rate = 0.0,
           block_kwargs...)
Creates the layers of a ResNe(X)t model. If you are an end-user, you should probably use
[ResNet](@ref) instead and pass in the parameters you want to modify as optional parameters
there.
# Arguments:
  - `block` / `block_kwargs`: The residual block to use in the model and the keyword arguments for it. See [basicblock](@ref) and [bottleneck](@ref) for
    example. This is called like `block(inplanes, outplanes; stride, block_kwargs...)`.
  - `layers`: A list of integers representing the number of blocks in each stage.
  - `stem`: The initial stage that operates on the input before the residual blocks. This can be any model that accepts the input and is compatible with the blocks stage. Defaults to [`somedefault`](#).
  - `nclasses`: The number of output classes. The default value is 1000.
  - `inchannels`: The number of input channels to the model. The default value is 3.
  - `output_stride`: The net stride of the model. Must be one of [8, 16, 32]. The default value is 32.
  - `reduce_first`: Reduction factor for first convolution output width of residual blocks.
    Default is 1 for all architectures except SE-Nets, where it is 2.
  - `activation`: The activation function to use. The default value is `relu`.
  - `norm_layer`: The normalization layer to use. The default value is `BatchNorm`.
  - `drop_rate`: The rate to use for `Dropout` before the fully-connected classifier stage. The default value is 0.0.
If you are an end-user trying to tweak the ResNet model, note that there is no guarantee that
all combinations of parameters will work. In particular, tweaking `block_kwargs` is not
advised unless you know what you are doing.
"""

I think the line: "This is called like block(inplanes, outplanes; stride, block_kwargs...)" is clear about the usage of block_kwargs. It's unspecified, but it is validated (when block itself is called) and it is clear that I need to know what keywords block accepts to understand this. At a higher level, like ResNet, I might think about also including another section "# Block arguments" that explains the standard arguments for basicblock and bottleneck.
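For concreteness, a call against that proposed signature might read roughly like this (illustrative only; whether bottleneck ends up taking cardinality as a keyword is an assumption here, not settled API):

# a ResNeXt-flavoured model through the proposed `resnet` builder:
# `cardinality` is forwarded untouched via `block_kwargs...`, so `resnet`
# itself never has to declare it
model = resnet(bottleneck, [3, 4, 6, 3]; nclasses = 1000, inchannels = 3,
               cardinality = 32)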

@theabhirath
Member Author

theabhirath commented Jun 25, 2022

Thank you for that writeup, it does clear some stuff up! I might need to do some homework before I get back with a response, but the two blog posts in particular might be good starting points in terms of understanding programming patterns in Julia a little better. I think most of my worry revolves around making ResNet safe - if someone is using resnet, I'm reasonably certain they know what they're doing. Right now ResNet is just a thin wrapper around resnet, though, so the documentation and the interaction between the two are something I am trying to get a cleaner picture of as this PR shapes up.

@theabhirath
Member Author

Okay, I've just pushed what I think is a more declarative interface (and it does look cleaner from the user's POV). This mostly revolves around exposing two arguments at the resnet level (and lower level builder functions as well): a *_fn and a *_args for the downsample block, the model stem and the main block. The *_fn is to allow flexibility around what choice to use, and the *_args is a NamedTuple for passing in arguments to the *_fn.
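Roughly, the call shape being described would read something like this (all of the names and argument values here are illustrative placeholders, not the final API):

# each component gets a `*_fn` choosing the builder and a `*_args` NamedTuple
# of arguments that is forwarded to that builder
model = resnet([3, 4, 6, 3];
               stem_fn = resnet_stem, stem_args = (; stem_type = :deep),
               block_fn = bottleneck, block_args = (; cardinality = 32),
               downsample_fn = downsample_conv, downsample_args = (;))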

I'm planning to rigorously document the choices of *_fn and *_args at the resnet level. So in this scheme, there are three "levels" we are catering to:

  1. End users who don't really care about experimental choices and just want to be able to instantiate a ResNet quickly without having to sift through a lot of complicated documentation. Currently, this is easy enough because ResNet will not have a lot of documentation around it. It will redirect the advanced user to consider resnet instead.
  2. Advanced users and writers of packages that depend on Metalhead.jl - for this level, the documentation surrounding resnet should be enough to try various experimental options without breaking stuff (of course, we will not guarantee this 😄 ).
  3. Metalhead.jl devs - unfortunately enough, we really do need to know exactly how every function works 😂. At this level, contributors and devs can read through comments and docstrings that I will populate for all these functions explaining practically everything so that there's no confusion on all the possible options.

The docs for this are missing because I wanna make sure that this interface is something that can be agreed upon before I proceed to write it up 😅 Any feedback is welcome!

1. Less keywords for the user to worry about
2. Delete `ResNeXt` just for now
@theabhirath
Member Author

Oh no. Did I manage to kill CI altogether somehow?

@theabhirath theabhirath reopened this Jun 29, 2022
Member

@darsnack darsnack left a comment

The design looks good, mostly minor changes here and there. I've been holding off on doing a full pass through all the other non-ResNet code, so I just did that and most of my comments are in those sections.

src/convnets/convmixer.jl (outdated, resolved)
src/convnets/resnets/core.jl (resolved)
src/convnets/resnets/core.jl (outdated, resolved)
Comment on lines 226 to 227
# inplanes increases by expansion after each block
inplanes = planes * expansion
Member

Suggested change
# inplanes increases by expansion after each block
inplanes = planes * expansion

Member Author

We need this, though. This is calculating the change in inplanes across blocks

Member

Maybe I am missing something but I don't see where the output of this calculation goes? It seems unused...unless it is modifying a global which is very bad.

Member Author

We were before, unfortunately. I've pushed a change. This makes resnet_planes return a vector instead of being a stage_idx-based callback - the reason we need this is that inplanes needs the planes from the previous block, not the current one, so we need to have access to that information.
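A rough sketch of that idea (not the actual code in this PR; the base width of 64 and the doubling per stage are just the standard ResNet convention):

# precompute `planes` for every block as a flat vector, so `inplanes` for
# block i can be read off from the planes (times expansion) of block i - 1
# instead of being tracked through a mutated local
function resnet_planes(block_repeats::Vector{<:Integer})
    return collect(Iterators.flatten(fill(64 * 2^(stage - 1), repeats)
                                     for (stage, repeats) in enumerate(block_repeats)))
end

resnet_planes([2, 2, 2, 2])
# 8-element Vector{Int64}: 64, 64, 128, 128, 256, 256, 512, 512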

src/convnets/resnets/core.jl (outdated, resolved)
src/convnets/resnets/seresnet.jl (outdated, resolved)
src/layers/conv.jl (outdated, resolved)
src/layers/conv.jl (outdated, resolved)
src/layers/pool.jl (outdated, resolved)
test/convnets.jl (outdated, resolved)
@theabhirath
Member Author

I've incorporated some of the docs changes and left out the others - these will need a thorough once-over anyway, and I want to try and get those in at the same time as the devdocs and the Documenter.jl port.

Co-Authored-By: Kyle Daruwalla <[email protected]>
@theabhirath theabhirath force-pushed the resnet-plus branch 2 times, most recently from d1d193a to 07c5c64 on July 29, 2022 17:53
Also misc. formatting and cleanup
@theabhirath
Member Author

I've also added Wide ResNet now (easy enough). But the CI is weird: I think my filtering should work, but the ResNet testset isn't executing at all.

@theabhirath
Member Author

Bump?

Member

@darsnack darsnack left a comment

Looks done to me. I just caught a couple last doc fixes and tests.

src/convnets/resnets/core.jl (outdated, resolved)
Comment on lines 226 to 227
# inplanes increases by expansion after each block
inplanes = planes * expansion
Member

Maybe I am missing something but I don't see where the output of this calculation goes? It seems unused...unless it is modifying a global which is very bad.

src/convnets/resnets/core.jl (resolved)
src/convnets/resnets/core.jl (resolved)
src/convnets/resnets/core.jl (resolved)
return Chain(stages...)
end

function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection,
Member

Docstring for each resnet?

Comment on lines +89 to +91
function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
norm_layer = BatchNorm, revnorm = false,
use_norm = (true, true), stride = 1, kwargs...)
Member

Think this needs a docstring update

test/convnets.jl (resolved)
.github/workflows/CI.yml (outdated, resolved)
Member

@darsnack darsnack left a comment

Great job @theabhirath! This is a HUGE improvement, so I appreciate all the time you put into it. I'm gonna let tests run to completion.

@theabhirath
Member Author

Yeah, I think this is the longest PR in terms of review comments on this repo, but it thoroughly deserved the discussion 😄 Happy to see this one through.

@theabhirath
Member Author

theabhirath commented Aug 2, 2022

I've also now made PRs to the HuggingFace repositories for the models. Once they're accepted, I'll push the updated pretrained weights links and SHAs as well. It would be good to have all the tests enabled and all the tasks ticked off 😄

@theabhirath
Member Author

I've also now made PRs to the HuggingFace repositories for the models. Once they're accepted, I'll push the updated pretrained weights links and SHAs as well. It would be good to have all the tests enabled and all the tasks ticked off 😄

On second thought, we might not want this to block the PR... I want to try and use the updated torchvision weights with higher accuracies - there have been some API changes, so this may take a little more time.

@darsnack darsnack merged commit 7e4f9db into FluxML:master Aug 2, 2022
@theabhirath theabhirath deleted the resnet-plus branch August 2, 2022 13:53