Skip to content

Commit

Permalink
Solved out_features typo in nn_linear.
Browse files Browse the repository at this point in the history
The public member of `nn_linear` for the output features should be called `out_features` instead of `out_feature` (s was missing).
Solved the typo and updated `test-nn.R` to not fail with the new name.
  • Loading branch information
nachodieez committed Sep 6, 2023
1 parent 5157860 commit 3574393
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions R/nn-linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ nn_identity <- nn_module(
nn_linear <- nn_module(
"nn_linear",
initialize = function(in_features, out_features, bias = TRUE) {
self$in_features <- in_features
self$out_feature <- out_features
self$in_features <- in_features
self$out_features <- out_features

self$weight <- nn_parameter(torch_empty(out_features, in_features))
if (bias) {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-nn.R
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ test_that("$<- works for instances", {
expect_s3_class(model, "nn_module")
model$mymodule <- nn_linear(2, 2)
expect_s3_class(model, "nn_module")
expect_equal(model$mymodule$out_feature, 2)
expect_equal(model$mymodule$out_features, 2)
model$new_module <- nn_linear(5, 5)
expect_s3_class(model, "nn_module")

Expand All @@ -333,7 +333,7 @@ test_that("[[<- works for instances", {
expect_s3_class(model, "nn_module")
model[["mymodule"]] <- nn_linear(2, 2)
expect_s3_class(model, "nn_module")
expect_equal(model$mymodule$out_feature, 2)
expect_equal(model$mymodule$out_features, 2)
model[["new_module"]] <- nn_linear(5, 5)
expect_s3_class(model, "nn_module")

Expand Down

0 comments on commit 3574393

Please sign in to comment.