From 4c93f8f6bb8b63b9c574ab8c097b51a9202a9f80 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 29 Aug 2023 17:01:36 -0300 Subject: [PATCH] Don't allow `NULLs` to be propagated in the dataloader. --- R/utils-data-fetcher.R | 10 ++----- src/utils.cpp | 7 +++-- tests/testthat/test-utils-data-dataloader.R | 32 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/R/utils-data-fetcher.R b/R/utils-data-fetcher.R index 5e37b66433..109ccca74c 100644 --- a/R/utils-data-fetcher.R +++ b/R/utils-data-fetcher.R @@ -36,21 +36,17 @@ IterableDatasetFetcher <- R6::R6Class( d <- self$dataset_iter() if (is_exhausted(d)) { - if (self$drop_last) { + if (self$drop_last || i == 1) { return(coro::exhausted()) } + # we drop the null values in that list. + data <- data[seq_len(i-1L)] break } data[[i]] <- d } - - # no data for the next batch, we return exhausted before trying anything - if (i == 1) { - return(coro::exhausted()) - } - } else { data <- self$dataset_iter() } diff --git a/src/utils.cpp b/src/utils.cpp index 8d95093ce5..850452ec98 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -65,9 +65,9 @@ std::thread::id main_thread_id() noexcept { // [[Rcpp::export]] Rcpp::List transpose2(Rcpp::List x) { auto templ = Rcpp::as(x[0]); - auto num_elements = templ.length(); + const auto num_elements = templ.length(); - auto size = x.length(); + const auto size = x.length(); std::vector out; for (auto i = 0; i < num_elements; i++) { @@ -75,6 +75,9 @@ Rcpp::List transpose2(Rcpp::List x) { } for (size_t j = 0; j < size; j++) { + if (Rf_isNull(x[j])) { + Rcpp::stop("NULL is not allowed. Expected a list."); + } auto el = Rcpp::as(x[j]); for (auto i = 0; i < num_elements; i++) { out[i][j] = el[i]; diff --git a/tests/testthat/test-utils-data-dataloader.R b/tests/testthat/test-utils-data-dataloader.R index d8791e105e..fe86ecb104 100644 --- a/tests/testthat/test-utils-data-dataloader.R +++ b/tests/testthat/test-utils-data-dataloader.R @@ -590,3 +590,35 @@ test_that("correctly reports length for iterable datasets that provide length", expect_equal(length(dl), 1) }) + +test_that("a case that errors in luz", { + + get_iterable_ds <- iterable_dataset( + "iterable_ds", + initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) { + self$len <- len + self$x <- torch::torch_randn(size = c(len, x_size)) + self$y <- torch::torch_randn(size = c(len, y_size)) + }, + .iter = function() { + i <- 0 + function() { + i <<- i + 1 + + if (i > self$len) { + return(coro::exhausted()) + } + + list( + x = self$x[i,..], + y = self$y[i,..] + ) + } + } + ) + + ds <- get_iterable_ds() + dl <- dataloader(ds, batch_size = 32) + expect_equal(length(coro::collect(dl)), 4) + +})