Skip to content

Commit

Permalink
Fix loading of scalar tensors. (#1234)
Browse files Browse the repository at this point in the history
* Fix loading of scalar tensors.

* Add regression test
  • Loading branch information
dfalbel authored Jan 16, 2025
1 parent eb05688 commit b31a5fc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions R/tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,9 @@ tensor_to_complex <- function(x) {
#'
#' @export
torch_tensor_from_buffer <- function(buffer, shape, dtype = "float") {
if (!is.integer(shape)) {
shape <- as.integer(shape)
}
cpp_tensor_from_buffer(buffer, shape, list(dtype=dtype))
}

Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-save.R
Original file line number Diff line number Diff line change
Expand Up @@ -424,4 +424,11 @@ test_that("can save a complex tensor", {

expect_true(torch_allclose(x$real, z$real))
expect_true(torch_allclose(x$imag, z$imag))
})

test_that("can load a scalar tensor", {
x <- torch_scalar_tensor(1)
k <- torch_serialize(x)
y <- torch_load(k)
expect_true(torch_allclose(x, y))
})

0 comments on commit b31a5fc

Please sign in to comment.