Implementing basic RNN #162
base: main
Conversation
Amazing, thanks, and great to see you, @castelao, after so many years. 🙂 Will review this coming week.
@milancurcic, yes, it's great to connect again. Thank you and the other developers for your time on this library! It is great. I'm not fluent in modern Fortran, so if you see anything that doesn't make sense, please let me know. And be aware that it is still a WIP.
Force-pushed from a9111b3 to 3fa5281.
I added the support for …
@milancurcic , I have to check how you updated the library since the last time I worked on this and see if what I did is still consistent. If it looks fine, I intend to work on the following:
Are there any other requirements before I can submit this PR for review? It will be slow progress, but I'm back to this.
As I am currently working with neural-fortran, I did a quick review of this PR and left a few minor comments. Overall LGTM. Thank you.
procedure :: get_params
procedure :: init
procedure :: set_params
! procedure :: set_state
Suggested change: remove the commented-out line "! procedure :: set_state".
!module subroutine set_state(self, state)
!  type(rnn_layer), intent(inout) :: self
!  real, intent(in), optional :: state(:)
!end subroutine set_state
Suggested change: delete the commented-out set_state stub entirely.
db = gradient * self % activation % eval_prime(self % z)
dw = matmul(reshape(input, [size(input), 1]), reshape(db, [1, size(db)]))
self % gradient = matmul(self % weights, db)
self % dw = self % dw + dw
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recently modified these lines in nf_dense_layer_submodule.f90 for better performance. The same logic could be applied here, IMO.
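One possible reading of that optimization (a hedged sketch, not the actual committed change in nf_dense_layer_submodule.f90) is to accumulate the outer product directly into self % dw, avoiding the temporary arrays allocated by reshape() and matmul() in the quoted snippet:

```fortran
! Hypothetical sketch: accumulate the weight-gradient outer product
! in place rather than building it through reshape()/matmul(),
! which allocates two temporaries per backward call.
db = gradient * self % activation % eval_prime(self % z)
do concurrent (j = 1:size(db), i = 1:size(input))
  self % dw(i, j) = self % dw(i, j) + input(i) * db(j)
end do
self % gradient = matmul(self % weights, db)
```

The result is identical; only the temporaries are eliminated.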
class(rnn_layer), intent(in) :: self
real, allocatable :: params(:)

params = [ &
The pack can be avoided here by using pointers, or it may simply not be needed. See here for some changes.
Same comment for the subroutines get_gradients and set_params below.
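A pointer-based alternative to pack, along the lines the reviewer suggests, might use a rank-remapping pointer assignment to expose the 2-D weights as a flat view without copying (a sketch only; the function name and signature here are hypothetical and may differ from the linked change):

```fortran
! Hypothetical sketch: return a flat 1-D pointer view of the 2-D
! weights via rank-remapping pointer assignment, instead of
! copying them with pack(). Requires the component to be contiguous.
module function get_params_ptr(self) result(w_ptr)
  class(rnn_layer), intent(in), target :: self
  real, pointer :: w_ptr(:)
  w_ptr(1:size(self % weights)) => self % weights
end function get_params_ptr
```

Callers then read (or update, via a non-intent(in) variant) the parameters through the view, so get_params, get_gradients, and set_params avoid an allocation and copy per call.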
@castelao Thanks for all the additions. Sounds good.
@jvdp1, thanks for your suggestions. I'll work on that. @milancurcic, yes, I have already rebased it with …
The dimensions don't match, but let's start with something that compiles.
Note a hardcoded 'simple_rnn_cell_23' that must be resolved later.
I'll try 1-D with a state memory and the option to reset the state when processing a new time series.
Each neuron is affected by all states. With this change, the forward procedure works correctly; I verified a couple of test cases.
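The forward step being described is the standard simple-RNN update, in which the recurrent weight matrix couples every neuron to the full previous state vector. A sketch (variable names here are illustrative, not the PR's actual ones):

```fortran
! Hypothetical sketch of one forward step of a simple RNN:
!   h_t = activation(W_x * x_t + W_h * h_{t-1} + b)
! The matmul with the recurrent weights is what makes each
! neuron's new state depend on all components of the old state.
self % z = matmul(self % input_weights, x) &
         + matmul(self % recurrent_weights, self % state) &
         + self % biases
self % state = self % activation % eval(self % z)
```

Resetting the layer for a new time series then just means zeroing self % state.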
Reset the state at the layer level instead of the network level.
Previously `quadratic_derivative`.
Co-authored-by: Jeremie Vandenplas <[email protected]>
A work in progress. I'm mostly interested in loading a TF model from HDF5 and applying predict(), but I'll do my best to produce a complete implementation consistent with the rest of the library.