
RNNs redesign #2500

Merged (4 commits) on Nov 4, 2024
Conversation

@CarloLucibello (Member) commented Oct 14, 2024

A complete rework of our recurrent layers, making them more similar to their PyTorch counterparts. This is in line with the proposal in #1365 and should allow us to hook into the cuDNN machinery (future PR). Hopefully, this puts an end to the endless trouble the recurrent layers have been.

  • Recur is no more. Mutating its internal state was a source of problems for AD (explicit differentiation for RNN gives wrong results #2185).
  • RNNCell is now exported and takes care of the minimal recursion step, i.e. a single time step:
    • has forward cell(x, h)
    • x can be of size in or in x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out or out x batch_size
  • RNN instead takes a (batched) sequence and a (batched) initial hidden state and returns the hidden states for the whole sequence:
    • has forward rnn(x, h)
    • x can be of size in x len or in x len x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out x len or out x len x batch_size
  • LSTM and GRU are similarly changed (see the usage sketch after this list).
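
A minimal usage sketch of the new call signatures, with made-up sizes; the exact constructor syntax may differ from what finally lands:

```julia
using Flux

in_dim, out_dim, len, batch = 3, 5, 7, 4

# One time step with the exported RNNCell: cell(x, h) -> hnew
cell = RNNCell(in_dim => out_dim)
x_t = rand(Float32, in_dim, batch)        # in x batch_size
h0  = zeros(Float32, out_dim, batch)      # out x batch_size
h1  = cell(x_t, h0)                       # out x batch_size

# Whole sequence with RNN: rnn(x, h) -> hidden states for every step
rnn = RNN(in_dim => out_dim)
x   = rand(Float32, in_dim, len, batch)   # in x len x batch_size
hs  = rnn(x, h0)                          # out x len x batch_size
```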

Close #2185, close #2341, close #2258, close #1547, close #807, close #1329

Related to #1678

PR Checklist

  • cpu tests
  • gpu tests
  • if hidden state not given as input, assumed to be zero (see the sketch after this checklist)
  • port LSTM and GRU
  • Entry in NEWS.md
  • Remove reset!
  • Docstrings
  • revisit documentation
  • add an option in constructors to have trainable initial state? (future PR)
  • use cuDNN (future PR)
  • implement the num_layers argument for stacked RNNs (future PR)
  • add dropout (future PR)
  • add bidirectional option (future PR)
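
For the checklist item about the hidden state defaulting to zero, the intended usage would look roughly like this (a sketch, not the final API):

```julia
using Flux

rnn = RNN(3 => 5)
x   = rand(Float32, 3, 7, 4)               # in x len x batch_size

y_default  = rnn(x)                        # hidden state omitted: assumed to be zero
y_explicit = rnn(x, zeros(Float32, 5, 4))  # state passed explicitly; should match, per the item above
```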

@darsnack (Member) commented Oct 18, 2024

Fully agree with updating the design to be non-mutating. There are two options we've discussed in the past:

  1. y, h = cell(x, h) like here (I guess this PR removes y as a return value, which is fine)
  2. y, cell = cell(x) / y, cell = Flux.apply(cell, x)

Option 1 is outlined in this PR so I won't say anything about it.

Option 2 is a more drastic redesign to make all layers (not just recurrent ones) non-mutating. Why?

  • It covers stateful layers in general (e.g. norm layers), not just recurrent cells.
  • It keeps a nice feature of Flux's current design, namely that the model contains all the info: parameters, state, flags, etc. (a rough sketch follows after this list).
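
Purely for illustration, not something in this PR: a sketch of how Option 2 could look if the state travels inside the returned layer value. The StatefulRNNCell wrapper and the apply function below are invented names, not existing Flux APIs.

```julia
using Flux

# Hypothetical wrapper carrying the hidden state alongside the cell.
struct StatefulRNNCell{C,H}
    cell::C
    state::H
end

# Option 2 style: y, layer = apply(layer, x). Nothing is mutated;
# the updated state comes back inside a fresh layer value.
function apply(c::StatefulRNNCell, x)
    hnew = c.cell(x, c.state)              # reuse the cell(x, h) forward from this PR
    return hnew, StatefulRNNCell(c.cell, hnew)
end

layer = StatefulRNNCell(RNNCell(3 => 5), zeros(Float32, 5, 2))
y, layer = apply(layer, rand(Float32, 3, 2))   # the caller threads the layer forward
```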

@CarloLucibello (Member, Author)

I thought about Option 2. On the upside, it seems like a nice intermediate spot between current Flux and Lux. The downside is that the interface would feel a bit exotic to Flux and PyTorch users. Moreover, it would be problematic for normalization layers.

Also, we need to distinguish between normalization layers and recurrent layers.

  • Normalization layers at training time update some internal buffers, within a stopgrad barrier. The buffer update has no influence on the output of the layer or on the final loss. You typically apply the layer only once during the forward pass, and normalization layers are typically part of larger models (chains or custom structs). Therefore, for normalization layers: 1) we haven't had the gradient-computation problems we had for recurrent layers; 2) you want the layer with the updated buffer to be inserted back into your model, but this would require a mutating operation or returning a new model.

  • For recurrent layers, Option 2 would be sensible, but is it worth it? Once you adopt the perspective that a cell takes two inputs, x and h, and gives back an output, hnew, all the problems disappear. I think trying to keep the state internal adds complexity for no gain (see the loop sketch below).
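
For example, unrolling a sequence is just explicit threading of the state through a loop (a minimal sketch with made-up sizes; the unroll helper is only for the example, and this is essentially what RNN does internally):

```julia
using Flux

# Illustrative helper: thread the hidden state through the sequence by hand.
function unroll(cell, xs, h)
    hs = typeof(h)[]
    for x in xs
        h = cell(x, h)      # plain data flow, no hidden mutation of the cell
        push!(hs, h)
    end
    return hs
end

cell = RNNCell(3 => 5)
xs   = [rand(Float32, 3, 8) for _ in 1:10]   # 10 time steps, batch of 8
hs   = unroll(cell, xs, zeros(Float32, 5, 8))
```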

@ToucheSir (Member)

The main benefit of keeping the state "internal", or of making it part of a unified interface like apply, is that Chain would work with RNNs again. Whether that's worth the extra complexity is the question. Given our priorities, I think it's best left as future work.

Commits

  • finish RNNCell
  • RNN rework
  • LSTMCell
  • LSTM
  • more work
  • gru
  • extended testing
  • runtests
  • add tests
  • reset! deprecation
  • fix test
  • unbreak l2 test
  • fix tests
  • fixes

codecov bot commented Oct 22, 2024

Codecov Report

Attention: Patch coverage is 76.92308% with 30 lines in your changes missing coverage. Please review.

Project coverage is 34.93%. Comparing base (c9bab66) to head (76cf275).
Report is 10 commits behind head on master.

Files with missing lines    Patch %   Missing lines
src/layers/recurrent.jl     78.74%    27 ⚠️
src/deprecations.jl          0.00%     3 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2500      +/-   ##
==========================================
+ Coverage   33.46%   34.93%   +1.46%     
==========================================
  Files          31       31              
  Lines        1829     1878      +49     
==========================================
+ Hits          612      656      +44     
- Misses       1217     1222       +5     


@CarloLucibello (Member, Author)

I think this is ready.

@CarloLucibello (Member, Author)

I will merge this tomorrow if there are no further comments or objections.
