Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/mlverse/torch
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Sep 1, 2023
2 parents 522501e + f957d60 commit 259025b
Show file tree
Hide file tree
Showing 11 changed files with 313 additions and 8 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ S3method(as.matrix,torch_tensor)
S3method(as.numeric,torch_tensor)
S3method(as_array,torch_tensor)
S3method(as_iterator,dataloader)
S3method(as_iterator,iterable_dataset)
S3method(as_iterator,utils_sampler)
S3method(asin,torch_tensor)
S3method(atan,torch_tensor)
Expand All @@ -83,6 +84,7 @@ S3method(expm1,torch_tensor)
S3method(floor,torch_tensor)
S3method(length,dataloader)
S3method(length,dataset)
S3method(length,iterable_dataset)
S3method(length,nn_module_list)
S3method(length,nn_sequential)
S3method(length,torch_tensor)
Expand All @@ -100,6 +102,7 @@ S3method(nn_prune_head,nn_sequential)
S3method(print,R7)
S3method(print,cuda_memory_stats)
S3method(print,dataset_generator)
S3method(print,iterable_dataset_generator)
S3method(print,nn_module)
S3method(print,script_function)
S3method(print,script_method)
Expand Down Expand Up @@ -178,6 +181,7 @@ export(is_torch_layout)
export(is_torch_memory_format)
export(is_torch_qscheme)
export(is_undefined_tensor)
export(iterable_dataset)
export(jit_compile)
export(jit_load)
export(jit_ops)
Expand Down Expand Up @@ -526,9 +530,11 @@ export(torch_ceil)
export(torch_celu)
export(torch_celu_)
export(torch_cfloat)
export(torch_cfloat128)
export(torch_cfloat32)
export(torch_cfloat64)
export(torch_chain_matmul)
export(torch_chalf)
export(torch_channel_shuffle)
export(torch_channels_last_format)
export(torch_cholesky)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Added support for CUDA 11.8. (#1089)
- Fixed segfault caused by comparing a `dtype` with a `NULL`. (#1090)
- Fixed incorrect naming of complex data type names, such as `torch_cfloat64`. (#1091)
- Added support for iterable datasets. (#1095)

# torch 0.11.0

Expand Down
30 changes: 28 additions & 2 deletions R/utils-data-dataloader.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,15 @@ DataLoader <- R6::R6Class(

if (is_map_dataset(dataset)) {
self$.dataset_kind <- "map"
} else if (is_iterable_dataset(dataset)) {
self$.dataset_kind <- "iterable"
} else {
cli::cli_abort("Unknown dataset type with class {.cls {class(dataset)}}")
}

if (is.null(sampler)) {
if (self$.dataset_kind == "iterable") {
# TODO
sampler <- InfiniteSampler()
} else {
if (shuffle) {
sampler <- RandomSampler(dataset, generator = generator)
Expand Down Expand Up @@ -200,13 +204,26 @@ DataLoader <- R6::R6Class(
}

MultiProcessingDataLoaderIter$new(self)
} else if (self$.dataset_kind == "iterable") {
if (self$num_workers == 0) {
return(SingleProcessDataLoaderIter$new(self))
}
cli::cli_abort("Multi-process dataloader not implemented yet for Iterable datasets.")
} else {
not_implemented_error()
}
},
.length = function() {
if (self$.dataset_kind == "iterable") {
not_implemented_error()
l <- length(self$dataset)

if (is.na(l)) return(l)

if (self$drop_last) {
return(l %/% self$batch_size)
} else {
return(as.integer(ceiling(l / self$batch_size)))
}
} else {
length(self$.index_sampler)
}
Expand Down Expand Up @@ -283,6 +300,13 @@ SingleProcessDataLoaderIter <- R6::R6Class(
self$.collate_fn,
self$.drop_last
)
} else if (self$.dataset_kind == "iterable") {
self$.dataset_fetcher <- IterableDatasetFetcher$new(
self$.dataset,
self$.auto_collation,
self$.collate_fn,
self$.drop_last
)
} else {
not_implemented_error()
}
Expand All @@ -294,7 +318,9 @@ SingleProcessDataLoaderIter <- R6::R6Class(
return(coro::exhausted())
}

# data can be exhausted in iterable datasets
data <- self$.dataset_fetcher$fetch(index)

if (self$.pin_memory) {
# TODO
}
Expand Down
7 changes: 7 additions & 0 deletions R/utils-data-fetcher.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ IterableDatasetFetcher <- R6::R6Class(
d <- self$dataset_iter()

if (is_exhausted(d)) {
if (self$drop_last || i == 1) {
return(coro::exhausted())
}

# we drop the null values in that list.
data <- data[seq_len(i-1L)]
break
}

Expand All @@ -44,6 +50,7 @@ IterableDatasetFetcher <- R6::R6Class(
} else {
data <- self$dataset_iter()
}

self$collate_fn(data)
}
)
Expand Down
13 changes: 13 additions & 0 deletions R/utils-data-sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ BatchSampler <- sampler(
}
)

InfiniteSampler <- sampler(
"infinite_sampler",
initialize = function() {},
.iter = function() {
function() {
TRUE
}
},
.length = function() {
Inf
}
)

#' @export
as_iterator.utils_sampler <- function(x) {
it <- x$.iter()
Expand Down
73 changes: 73 additions & 0 deletions R/utils-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,32 @@ Dataset <- R6::R6Class(
)
)

IterableDataset <- R6::R6Class(
classname = "iterable_dataset",
lock_objects = FALSE,
public = list(
.iter = function() {
not_implemented_error()
},
.length = function() {
NA_integer_
}
)
)

is_map_dataset <- function(x) {
inherits(x, "dataset")
}

is_iterable_dataset <- function(x) {
inherits(x, "iterable_dataset")
}

#' @export
as_iterator.iterable_dataset <- function(x) {
x$.iter()
}

get_init <- function(x) {
if (!is.null(x$public_methods$initialize)) {
return(x$public_methods$initialize)
Expand Down Expand Up @@ -107,12 +129,58 @@ dataset <- function(name = NULL, inherit = Dataset, ...,
)
}


#' Creates an iterable dataset
#'
#' @inheritParams dataset
#' @examples
#' ids <- iterable_dataset(
#' name = "hello",
#' initialize = function(n = 5) {
#' self$n <- n
#' self$i <- 0
#' },
#' .iter = function() {
#' i <- 0
#' function() {
#' i <<- i + 1
#' if (i > self$n) {
#' coro::exhausted()
#' } else {
#' i
#' }
#' }
#' }
#' )
#' coro::collect(ids()$.iter())
#' @export
iterable_dataset <- function(name, inherit = IterableDataset, ...,
private = NULL, active = NULL,
parent_env = parent.frame()) {
create_class(
name = name,
inherit = inherit,
...,
private = private,
active = active,
parent_env = parent_env,
attr_name = "Dataset",
constructor_class = "iterable_dataset_generator"
)
}

#' @export
print.dataset_generator <- function(x, ...) {
cli::cat_line("<dataset_generator>")
print(attr(x, "Dataset"))
}

#' @export
print.iterable_dataset_generator <- function(x, ...) {
cli::cat_line("<iterable_dataset_generator>")
print(attr(x, "IterableDataset"))
}

#' @export
`[.dataset` <- function(x, y) {
y <- as.integer(y)
Expand All @@ -136,6 +204,11 @@ length.dataset <- function(x) {
x$.length()
}

#' @export
length.iterable_dataset <- function(x) {
x$.length()
}

#' Dataset wrapping tensors.
#'
#' Each sample will be retrieved by indexing tensors along the first dimension.
Expand Down
57 changes: 57 additions & 0 deletions man/iterable_dataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 10 additions & 4 deletions man/torch_dtype.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions src/lantern/src/Dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ LANTERN_DTYPE_FUN(byte, kByte)

void* _lantern_Dtype_from_string (void* dtype_str) {
LANTERN_FUNCTION_START

if (!dtype_str) {
throw std::runtime_error("Error dtype can't be NULL");
}

auto str = from_raw::string(dtype_str);
auto dtype = [&str] () {
if (str == "float" || str == "float32") {
Expand Down
7 changes: 5 additions & 2 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,19 @@ std::thread::id main_thread_id() noexcept {
// [[Rcpp::export]]
Rcpp::List transpose2(Rcpp::List x) {
auto templ = Rcpp::as<Rcpp::List>(x[0]);
auto num_elements = templ.length();
const auto num_elements = templ.length();

auto size = x.length();
const auto size = x.length();
std::vector<Rcpp::List> out;

for (auto i = 0; i < num_elements; i++) {
out.push_back(Rcpp::List(size));
}

for (size_t j = 0; j < size; j++) {
if (Rf_isNull(x[j])) {
Rcpp::stop("NULL is not allowed. Expected a list.");
}
auto el = Rcpp::as<Rcpp::List>(x[j]);
for (auto i = 0; i < num_elements; i++) {
out[i][j] = el[i];
Expand Down
Loading

0 comments on commit 259025b

Please sign in to comment.