diff --git a/.Rbuildignore b/.Rbuildignore
index 56af93d107..b02425384a 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -40,14 +40,14 @@ mnist-r.*
 # ^vignettes/using-autograd\.Rmd
 
 # uncomment below for CRAN submission
-^inst/bin/.*
-^inst/include/(?!torch.h|lantern|torch_RcppExports.h|utils.h|torch_impl.h|torch_types.h|torch_api.h|torch_deleters.h|torch_imports.h).*
-^inst/lib/.*
-^inst/share/.*
-^inst/build-hash
-^inst/build-versions
-^src/lantern/.*
-^tests/testthat/assets/model-v.*
+# ^inst/bin/.*
+# ^inst/include/(?!torch.h|lantern|torch_RcppExports.h|utils.h|torch_impl.h|torch_types.h|torch_api.h|torch_deleters.h|torch_imports.h).*
+# ^inst/lib/.*
+# ^inst/share/.*
+# ^inst/build-hash
+# ^inst/build-versions
+# ^src/lantern/.*
+# ^tests/testthat/assets/model-v.*
 
 ^doc$
 ^Meta$
diff --git a/R/indexing.R b/R/indexing.R
index cd6f6fdbf4..3f8c27ce12 100644
--- a/R/indexing.R
+++ b/R/indexing.R
@@ -50,7 +50,7 @@ print.slice <- function(x, ...) {
   N = .Machine$integer.max,
   newaxis = NULL,
   `..` = structure(list(), class = "fill")
-)
+)
 
 tensor_slice <- function(tensor, ..., drop = TRUE) {
   Tensor_slice(tensor, environment(), drop = drop, mask = .d)
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 9869d72948..8554f6620e 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -44146,11 +44146,11 @@ BEGIN_RCPP
 END_RCPP
 }
 // Tensor_slice_put
-void Tensor_slice_put(Rcpp::XPtr self, Rcpp::Environment e, SEXP rhs, Rcpp::List mask);
+void Tensor_slice_put(XPtrTorchTensor self, Rcpp::Environment e, SEXP rhs, Rcpp::List mask);
 RcppExport SEXP _torch_Tensor_slice_put(SEXP selfSEXP, SEXP eSEXP, SEXP rhsSEXP, SEXP maskSEXP) {
 BEGIN_RCPP
     Rcpp::RNGScope rcpp_rngScope_gen;
-    Rcpp::traits::input_parameter< Rcpp::XPtr >::type self(selfSEXP);
+    Rcpp::traits::input_parameter< XPtrTorchTensor >::type self(selfSEXP);
     Rcpp::traits::input_parameter< Rcpp::Environment >::type e(eSEXP);
     Rcpp::traits::input_parameter< SEXP >::type rhs(rhsSEXP);
    Rcpp::traits::input_parameter< Rcpp::List >::type mask(maskSEXP);
diff --git a/src/indexing.cpp b/src/indexing.cpp
index ffd813595f..00b313f022 100644
--- a/src/indexing.cpp
+++ b/src/indexing.cpp
@@ -205,7 +205,7 @@ struct index_info {
 // returns true if appended a vector like object. We use the boolean vector
 // to decide if we should start a new index object.
 index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
-                             bool drop) {
+                             bool drop, torch::Device device) {
   // a single NA means empty argument which and in turn we must select
   // all elements in that dimension.
   if (TYPEOF(slice) == LGLSXP && LENGTH(slice) == 1 &&
@@ -249,13 +249,26 @@ index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
   // if it's a numeric vector
   if ((TYPEOF(slice) == REALSXP || TYPEOF(slice) == INTSXP) &&
       LENGTH(slice) > 1) {
-    index_append_integer_vector(index, slice);
-    return {1, true, false};
+    // if it's a numeric vector but has a dim attribute, we convert the value to a Tensor
+    // before adding it to the index.
+    const auto dims = Rcpp::RObject(Rf_getAttrib(slice, R_DimSymbol));
+    if (Rf_isNull(dims)) {
+      index_append_integer_vector(index, slice);
+      return {1, true, false};
+    }
+    // If the slice has a dim attribute, we convert it to a tensor and let the code
+    // continue to add it to the index.
+    slice = torch_tensor_cpp(slice, torch::Dtype(lantern_Dtype_int64()), device);
   }
 
   if (TYPEOF(slice) == LGLSXP) {
-    index_append_bool_vector(index, slice);
-    return {1, true, false};
+    const auto dims = Rcpp::RObject(Rf_getAttrib(slice, R_DimSymbol));
+    if (Rf_isNull(dims)) {
+      index_append_bool_vector(index, slice);
+      return {1, true, false};
+    }
+    // convert to tensor and let it fall through
+    slice = torch_tensor_cpp(slice, torch::Dtype(lantern_Dtype_bool()));
   }
 
   if (Rf_inherits(slice, "torch_tensor")) {
@@ -271,7 +284,7 @@ index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
 }
 
 std::vector slices_to_index(
-    std::vector slices, bool drop) {
+    std::vector slices, bool drop, torch::Device device) {
   std::vector output;
   XPtrTorchTensorIndex index = lantern_TensorIndex_new();
   SEXP slice;
@@ -279,7 +292,7 @@ std::vector slices_to_index(
   bool has_ellipsis = false;
   for (auto i = 0; i < slices.size(); i++) {
     slice = slices[i];
-    auto info = index_append_sexp(index, slice, drop);
+    auto info = index_append_sexp(index, slice, drop, device);
 
     if (!has_ellipsis && info.ellipsis) {
       has_ellipsis = true;
@@ -328,7 +341,8 @@ std::vector slices_to_index(
 XPtrTorchTensor Tensor_slice(XPtrTorchTensor self, Rcpp::Environment e,
                              bool drop, Rcpp::List mask) {
   auto dots = evaluate_slices(enquos0(e), mask);
-  auto index = slices_to_index(dots, drop);
+  auto device = torch::Device(lantern_Tensor_device(self.get()));
+  auto index = slices_to_index(dots, drop, device);
   XPtrTorchTensor out = self;
   for (auto& ind : index) {
     out = lantern_Tensor_index(out.get(), ind.get());
@@ -339,10 +353,11 @@ XPtrTorchTensor Tensor_slice(XPtrTorchTensor self, Rcpp::Environment e,
 XPtrTorchScalar cpp_torch_scalar(SEXP x);
 
 // [[Rcpp::export]]
-void Tensor_slice_put(Rcpp::XPtr self, Rcpp::Environment e,
+void Tensor_slice_put(XPtrTorchTensor self, Rcpp::Environment e,
                       SEXP rhs, Rcpp::List mask) {
   auto dots = evaluate_slices(enquos0(e), mask);
-  auto indexes = slices_to_index(dots, true);
+  auto device = torch::Device(lantern_Tensor_device(self.get()));
+  auto indexes = slices_to_index(dots, true, device);
 
   if (indexes.size() > 1) {
     Rcpp::stop(
@@ -356,13 +371,13 @@ void Tensor_slice_put(Rcpp::XPtr self, Rcpp::Environment e,
        TYPEOF(rhs) == LGLSXP || TYPEOF(rhs) == STRSXP) &&
       LENGTH(rhs) == 1) {
     auto s = cpp_torch_scalar(rhs);
-    lantern_Tensor_index_put_scalar_(self->get(), index.get(), s.get());
+    lantern_Tensor_index_put_scalar_(self.get(), index.get(), s.get());
     return;
   }
 
   if (Rf_inherits(rhs, "torch_tensor")) {
     Rcpp::XPtr t = Rcpp::as>(rhs);
-    lantern_Tensor_index_put_tensor_(self->get(), index.get(), t->get());
+    lantern_Tensor_index_put_tensor_(self.get(), index.get(), t->get());
     return;
   }
diff --git a/src/lantern/src/Tensor.cpp b/src/lantern/src/Tensor.cpp
index cf122e8d80..6e26e4f16a 100644
--- a/src/lantern/src/Tensor.cpp
+++ b/src/lantern/src/Tensor.cpp
@@ -156,7 +156,6 @@ bool *_lantern_Tensor_data_ptr_bool(void *self) {
 int64_t _lantern_Tensor_numel(void *self) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   return x.numel();
   LANTERN_FUNCTION_END_RET(0)
 }
@@ -164,7 +163,6 @@ int64_t _lantern_Tensor_numel(void *self) {
 int64_t _lantern_Tensor_element_size(void *self) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   return x.element_size();
   LANTERN_FUNCTION_END_RET(0)
 }
@@ -172,7 +170,6 @@ int64_t _lantern_Tensor_element_size(void *self) {
 int64_t _lantern_Tensor_ndimension(void *self) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   return x.ndimension();
   LANTERN_FUNCTION_END_RET(0)
 }
@@ -180,7 +177,6 @@ int64_t _lantern_Tensor_ndimension(void *self) {
 int64_t _lantern_Tensor_size(void *self, int64_t i) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   return x.size(i);
   LANTERN_FUNCTION_END_RET(0)
 }
@@ -188,7 +184,6 @@ int64_t _lantern_Tensor_size(void *self, int64_t i) {
 void *_lantern_Tensor_dtype(void *self) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   torch::Dtype dtype = c10::typeMetaToScalarType(x.dtype());
   return make_raw::Dtype(dtype);
   LANTERN_FUNCTION_END
@@ -197,7 +192,6 @@ void *_lantern_Tensor_dtype(void *self) {
 void *_lantern_Tensor_device(void *self) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   torch::Device device = x.device();
   return make_raw::Device(device);
   LANTERN_FUNCTION_END
@@ -238,7 +232,6 @@ void *_lantern_Tensor_names(void *self) {
 bool _lantern_Tensor_has_any_zeros(void *self) {
   LANTERN_FUNCTION_START
   torch::Tensor x = from_raw::Tensor(self);
-  ;
   return (x == 0).any().item().toBool();
   LANTERN_FUNCTION_END
 }
diff --git a/tests/testthat/test-indexing.R b/tests/testthat/test-indexing.R
index 39dff44dbd..8691fc3f34 100644
--- a/tests/testthat/test-indexing.R
+++ b/tests/testthat/test-indexing.R
@@ -266,3 +266,41 @@ test_that("NULL tensor", {
   expect_error(torch_tensor(as.integer(NULL))[1], regexp = "out of bounds")
 })
+
+test_that("works with numeric / logical matrix", {
+  # Regression test for: https://github.com/mlverse/torch/issues/1181
+  x <- torch_randn(4, 4)
+  y <- rbind(c(1, 1), c(1, 2))
+
+  expect_true(
+    torch_allclose(
+      x[y],
+      x[torch_tensor(y, dtype = "long")]
+    )
+  )
+
+  expect_true(
+    torch_allclose(
+      x[x > 0],
+      x[as.array(x > 0)]
+    )
+  )
+
+  # also test that it works when the tensor is on a different device
+  skip_if_not_m1_mac()
+  x <- x$to(device = "mps")
+
+  expect_true(
+    torch_allclose(
+      x[y],
+      x[torch_tensor(y, dtype = "long")]
+    )
+  )
+
+  expect_true(
+    torch_allclose(
+      x[x > 0],
+      x[as.array(x > 0)]
+    )
+  )
+})
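# Not part of the patch: a small R sketch of the behaviour the new tests cover,
# assuming the torch package is installed; the "mps" note only applies on Apple
# Silicon machines where that backend is available.
library(torch)

x <- torch_randn(4, 4)
y <- rbind(c(1, 1), c(1, 2))

# indexing with an R numeric matrix now gives the same result as indexing
# with the equivalent long tensor
x[y]
x[torch_tensor(y, dtype = "long")]

# indexing with an R logical matrix matches boolean tensor indexing
x[as.array(x > 0)]
x[x > 0]

# the index is now built on the tensor's own device, so the same expressions
# also work for non-CPU tensors, e.g. after x <- x$to(device = "mps")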