Skip to content

Commit

Permalink
Don't allow NULLs to be propagated in the dataloader.
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Aug 29, 2023
1 parent 12d1378 commit 4c93f8f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
10 changes: 3 additions & 7 deletions R/utils-data-fetcher.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,17 @@ IterableDatasetFetcher <- R6::R6Class(
d <- self$dataset_iter()

if (is_exhausted(d)) {
if (self$drop_last) {
if (self$drop_last || i == 1) {
return(coro::exhausted())
}

# we drop the null values in that list.
data <- data[seq_len(i-1L)]
break
}

data[[i]] <- d
}

# no data for the next batch, we return exhausted before trying anything
if (i == 1) {
return(coro::exhausted())
}

} else {
data <- self$dataset_iter()
}
Expand Down
7 changes: 5 additions & 2 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,19 @@ std::thread::id main_thread_id() noexcept {
// [[Rcpp::export]]
Rcpp::List transpose2(Rcpp::List x) {
auto templ = Rcpp::as<Rcpp::List>(x[0]);
auto num_elements = templ.length();
const auto num_elements = templ.length();

auto size = x.length();
const auto size = x.length();
std::vector<Rcpp::List> out;

for (auto i = 0; i < num_elements; i++) {
out.push_back(Rcpp::List(size));
}

for (size_t j = 0; j < size; j++) {
if (Rf_isNull(x[j])) {
Rcpp::stop("NULL is not allowed. Expected a list.");
}
auto el = Rcpp::as<Rcpp::List>(x[j]);
for (auto i = 0; i < num_elements; i++) {
out[i][j] = el[i];
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-utils-data-dataloader.R
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,35 @@ test_that("correctly reports length for iterable datasets that provide length",
expect_equal(length(dl), 1)

})

test_that("a case that errors in luz", {

get_iterable_ds <- iterable_dataset(
"iterable_ds",
initialize = function(len = 100, x_size = 10, y_size = 1, fixed_values = FALSE) {
self$len <- len
self$x <- torch::torch_randn(size = c(len, x_size))
self$y <- torch::torch_randn(size = c(len, y_size))
},
.iter = function() {
i <- 0
function() {
i <<- i + 1

if (i > self$len) {
return(coro::exhausted())
}

list(
x = self$x[i,..],
y = self$y[i,..]
)
}
}
)

ds <- get_iterable_ds()
dl <- dataloader(ds, batch_size = 32)
expect_equal(length(coro::collect(dl)), 4)

})

0 comments on commit 4c93f8f

Please sign in to comment.