Support TensorModule of distinct input/output types #34
After doing a bit of research, it seems like [...] So, a couple of thoughts [...]
That's good to know. One thing to keep in mind is that the included models are mostly basic building blocks, and I'm not sure what input/output structures and types look like in more complex custom modules. That's another good reason to implement a few more architectures, like transformers, to get a better feel for it. :)
It would be interesting to see whether it's possible to create an easy-to-use API for a) [...]

One general thought regarding the parameter types of modules: we need to consider their recursive structure and mutability. Here's what I said in another discussion about that: [...]
So perhaps we need to provide an immutable module API and make sure to copy the module and all its submodules recursively to make this safe. Or perhaps you have another idea of how to deal with it.
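The recursive-copy idea can be sketched as follows. This is a minimal illustration under stated assumptions, not the project's API: `ToyModule`, `deepCopy`, and `mapParams` are hypothetical names, and parameters are modeled as plain mutable `Array[Double]` standing in for mutable tensor storage. Every update goes through a deep copy of the whole module tree, so the original is never mutated.

```scala
// Hypothetical module tree whose parameters live in mutable arrays,
// standing in for mutable native tensor storage (not a real library API).
final class ToyModule(
    val name: String,
    val params: Map[String, Array[Double]],
    val children: Vector[ToyModule]
):
  // Copy this module and all of its submodules recursively, cloning parameter
  // storage so that the copy shares no mutable state with the original.
  def deepCopy: ToyModule =
    ToyModule(name, params.map { case (k, v) => k -> v.clone() }, children.map(_.deepCopy))

  // Immutable-style update: transforms the parameters of a deep copy
  // and leaves `this` untouched.
  def mapParams(f: Double => Double): ToyModule =
    val copy = deepCopy
    copy.foreachParam(arr => for i <- arr.indices do arr(i) = f(arr(i)))
    copy

  // Visit the parameter arrays of this module and, recursively, of every submodule.
  private def foreachParam(g: Array[Double] => Unit): Unit =
    params.values.foreach(g)
    children.foreach(_.foreachParam(g))

val original = ToyModule(
  "sequential",
  Map.empty,
  Vector(ToyModule("linear", Map("weight" -> Array(1.0, 2.0)), Vector.empty))
)
// `updated` holds scaled weights; `original` keeps its old values.
val updated = original.mapParams(_ * 10)
```

The cost of this design is a full copy per update; sharing unchanged subtrees would be cheaper but only safe if parameter storage itself were immutable.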
Currently `TensorModule` is parametrized on a single type, so it keeps the transformation within the same DType.

However, there are modules where the input type differs from the output type, such as `nn.Embedding`, which accepts `Int` or `Long` input indexes while the output can be any DType. So the trait would need to be parametrized on both the input and the output type, and `nn.Embedding` would then be implemented against that two-parameter trait.
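To make the difference between the single-type and the two-type parametrization concrete, here is a self-contained sketch. It is not the project's code: `DType`, `Int64`, `Float32`, `Tensor`, `TypedModule`, `SameTypeModule`, `Scale`, and `ToyEmbedding` are hypothetical stand-ins, and the toy `Tensor` stores every value as `Double`, tracking the DType only as a phantom type parameter.

```scala
// Hypothetical stand-ins for a tensor library's DType hierarchy and Tensor type.
// Values are stored as Double for simplicity; only the phantom type D tracks the DType.
sealed trait DType
sealed trait Int64 extends DType
sealed trait Float32 extends DType

final case class Tensor[D <: DType](values: Vector[Double], shape: List[Int])

// Two type parameters: the input and output DTypes may differ.
trait TypedModule[In <: DType, Out <: DType]:
  def apply(input: Tensor[In]): Tensor[Out]

// The single-parameter shape becomes the special case In == Out.
type SameTypeModule[D <: DType] = TypedModule[D, D]

// A DType-preserving module: multiplies every element by a constant.
final class Scale[D <: DType](factor: Double) extends SameTypeModule[D]:
  def apply(input: Tensor[D]): Tensor[D] =
    Tensor[D](input.values.map(_ * factor), input.shape)

// A toy embedding: integer indices in, one weight row per index out,
// so the output shape is the index shape plus the embedding dimension.
final class ToyEmbedding[Out <: DType](weight: Vector[Vector[Double]])
    extends TypedModule[Int64, Out]:
  def apply(indices: Tensor[Int64]): Tensor[Out] =
    val dim = weight.head.length
    val rows = indices.values.map(i => weight(i.toInt))
    Tensor[Out](rows.flatten, indices.shape :+ dim)

val embedding = ToyEmbedding[Float32](Vector(Vector(0.0, 1.0), Vector(2.0, 3.0), Vector(4.0, 5.0)))
val indices = Tensor[Int64](Vector(2.0, 0.0), List(2))
val embedded: Tensor[Float32] = embedding(indices)
// embedding(embedded) would not compile: a Tensor[Float32] is not a Tensor[Int64].
```

Expressing the single-type case as an alias of the two-parameter trait is one way to keep existing same-type modules source-compatible while allowing modules like an embedding to change the DType.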
This is doable; however, there are useful operators on `TensorModule`, such as `nn.Sequential`, which expects an array of modules to chain. With a single type parameter, compile-time validation is straightforward, but with distinct input/output types it becomes more complex to validate the chain at compile time. I will do some research on how to solve this.
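One possible way to keep compile-time validation when chaining modules with distinct types is to compose them pairwise with a typed operator instead of collecting them in an array, so the compiler checks that each module's output DType matches the next module's input DType. This is a sketch under assumptions, not the project's `nn.Sequential`: `DType`, `Int64`, `Float32`, `Tensor`, `TypedModule`, `>>`, `Identity`, `Scale`, `ToyEmbedding`, and `sequential` are all hypothetical.

```scala
// Hypothetical stand-ins for a tensor library's DType hierarchy and Tensor type.
sealed trait DType
sealed trait Int64 extends DType
sealed trait Float32 extends DType

final case class Tensor[D <: DType](values: Vector[Double], shape: List[Int])

trait TypedModule[In <: DType, Out <: DType]:
  def apply(input: Tensor[In]): Tensor[Out]

  // Typed composition: only compiles when this module's Out matches next's In.
  def >>[Next <: DType](next: TypedModule[Out, Next]): TypedModule[In, Next] =
    val first = this
    new TypedModule[In, Next] {
      def apply(input: Tensor[In]): Tensor[Next] = next(first(input))
    }

final class Identity[D <: DType] extends TypedModule[D, D]:
  def apply(input: Tensor[D]): Tensor[D] = input

final class Scale[D <: DType](factor: Double) extends TypedModule[D, D]:
  def apply(input: Tensor[D]): Tensor[D] =
    Tensor[D](input.values.map(_ * factor), input.shape)

final class ToyEmbedding[Out <: DType](weight: Vector[Vector[Double]])
    extends TypedModule[Int64, Out]:
  def apply(indices: Tensor[Int64]): Tensor[Out] =
    val dim = weight.head.length
    Tensor[Out](indices.values.flatMap(i => weight(i.toInt)), indices.shape :+ dim)

// Same-type chaining: a homogeneous Seq is enough because every step maps D to D.
def sequential[D <: DType](modules: Seq[TypedModule[D, D]]): TypedModule[D, D] =
  modules.foldLeft[TypedModule[D, D]](Identity[D]())(_ >> _)

val embedding = ToyEmbedding[Float32](Vector(Vector(0.0, 1.0), Vector(2.0, 3.0), Vector(4.0, 5.0)))
// Heterogeneous chaining: Int64 -> Float32 -> Float32, checked at compile time.
val model: TypedModule[Int64, Float32] = embedding >> Scale[Float32](2.0)
val out = model(Tensor[Int64](Vector(2.0, 0.0), List(2)))
// Scale[Float32](2.0) >> embedding would not compile: Float32 output vs Int64 input.
val scaled = sequential[Float32](Seq(Scale[Float32](2.0), Scale[Float32](3.0)))
```

The trade-off: a homogeneous `Seq` still works for same-type modules, but mixed-type chains have to be built by composition (or with a heterogeneous list type), because a plain `Seq[TypedModule[?, ?]]` erases exactly the type information the compiler needs to validate the chain.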
Any pointers or ideas are more than welcome!