Provide a way to set gradients to none (#1195)
* Provide a way to set gradients to none

* Use a different implementation
dfalbel authored Sep 20, 2024
1 parent 030cda0 commit ab71bda
Showing 8 changed files with 39 additions and 22 deletions.
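
As a quick orientation (not itself part of the diff), a minimal sketch of how the new argument is meant to be used from R; optim_sgd and the values below are just an example, mirroring the test added at the bottom of this commit:

library(torch)

x <- torch_tensor(1, requires_grad = TRUE)
opt <- optim_sgd(x, lr = 0.1)

(2 * x)$backward()
opt$step()

# default: gradients are kept and zeroed in place, as before
opt$zero_grad()

(2 * x)$backward()
opt$step()

# new: release the gradient buffers instead of zeroing them
opt$zero_grad(set_to_none = TRUE)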
4 changes: 2 additions & 2 deletions R/RcppExports.R
@@ -185,8 +185,8 @@ cpp_set_cuda_allocator_allocator_thresholds <- function(reserved_rate, allocated
     invisible(.Call(`_torch_cpp_set_cuda_allocator_allocator_thresholds`, reserved_rate, allocated_rate, allocated_reserved_rate))
 }
 
-cpp_autograd_zero_grad <- function(x) {
-    invisible(.Call(`_torch_cpp_autograd_zero_grad`, x))
+cpp_autograd_zero_grad <- function(x, set_to_none) {
+    invisible(.Call(`_torch_cpp_autograd_zero_grad`, x, set_to_none))
 }
 
 cpp_backends_mkldnn_is_available <- function() {
4 changes: 2 additions & 2 deletions R/optim.R
@@ -82,9 +82,9 @@ Optimizer <- R6::R6Class(
 
       self$param_groups <- append(self$param_groups, list(param_group))
     },
-    zero_grad = function() {
+    zero_grad = function(set_to_none = FALSE) {
       for (group in self$param_groups) {
-        cpp_autograd_zero_grad(group$params)
+        cpp_autograd_zero_grad(group$params, set_to_none)
       }
     },
     state_dict = function() {
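
Note that the new argument defaults to set_to_none = FALSE, so zero_grad() keeps its previous in-place zeroing behaviour and existing callers are unaffected; the "none" semantics are opt-in. This mirrors the set_to_none argument of Optimizer.zero_grad() in (Python) PyTorch, where recent releases default it to True.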
6 changes: 3 additions & 3 deletions inst/include/lantern/lantern.h
@@ -2405,10 +2405,10 @@ HOST_API void* lantern_IntArrayRef_get (void* x)
   return ret;
 }
 
-LANTERN_API void (LANTERN_PTR _lantern_autograd_zero_grad) (void * self);
-HOST_API void lantern_autograd_zero_grad (void * self) {
+LANTERN_API void (LANTERN_PTR _lantern_autograd_zero_grad) (void * self, bool set_to_none);
+HOST_API void lantern_autograd_zero_grad (void * self, bool set_to_none) {
   LANTERN_CHECK_LOADED
-  _lantern_autograd_zero_grad(self);
+  _lantern_autograd_zero_grad(self, set_to_none);
   LANTERN_HOST_HANDLER;
 }
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
@@ -518,12 +518,13 @@ BEGIN_RCPP
 END_RCPP
 }
 // cpp_autograd_zero_grad
-void cpp_autograd_zero_grad(torch::TensorList x);
-RcppExport SEXP _torch_cpp_autograd_zero_grad(SEXP xSEXP) {
+void cpp_autograd_zero_grad(torch::TensorList x, bool set_to_none);
+RcppExport SEXP _torch_cpp_autograd_zero_grad(SEXP xSEXP, SEXP set_to_noneSEXP) {
 BEGIN_RCPP
     Rcpp::RNGScope rcpp_rngScope_gen;
     Rcpp::traits::input_parameter< torch::TensorList >::type x(xSEXP);
-    cpp_autograd_zero_grad(x);
+    Rcpp::traits::input_parameter< bool >::type set_to_none(set_to_noneSEXP);
+    cpp_autograd_zero_grad(x, set_to_none);
     return R_NilValue;
 END_RCPP
 }
@@ -45647,7 +45648,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"_torch_cpp_autograd_grad", (DL_FUNC) &_torch_cpp_autograd_grad, 6},
     {"_torch_cpp_set_lantern_allocator", (DL_FUNC) &_torch_cpp_set_lantern_allocator, 1},
     {"_torch_cpp_set_cuda_allocator_allocator_thresholds", (DL_FUNC) &_torch_cpp_set_cuda_allocator_allocator_thresholds, 3},
-    {"_torch_cpp_autograd_zero_grad", (DL_FUNC) &_torch_cpp_autograd_zero_grad, 1},
+    {"_torch_cpp_autograd_zero_grad", (DL_FUNC) &_torch_cpp_autograd_zero_grad, 2},
     {"_torch_cpp_backends_mkldnn_is_available", (DL_FUNC) &_torch_cpp_backends_mkldnn_is_available, 0},
     {"_torch_cpp_backends_mkl_is_available", (DL_FUNC) &_torch_cpp_backends_mkl_is_available, 0},
     {"_torch_cpp_backends_openmp_is_available", (DL_FUNC) &_torch_cpp_backends_openmp_is_available, 0},
4 changes: 2 additions & 2 deletions src/autograd.cpp
@@ -416,6 +416,6 @@ void cpp_set_cuda_allocator_allocator_thresholds (double reserved_rate, double a
 }
 
 // [[Rcpp::export]]
-void cpp_autograd_zero_grad (torch::TensorList x) {
-  lantern_autograd_zero_grad(x.get());
+void cpp_autograd_zero_grad (torch::TensorList x, bool set_to_none) {
+  lantern_autograd_zero_grad(x.get(), set_to_none);
 }
6 changes: 3 additions & 3 deletions src/lantern/include/lantern/lantern.h
@@ -2405,10 +2405,10 @@ HOST_API void* lantern_IntArrayRef_get (void* x)
   return ret;
 }
 
-LANTERN_API void (LANTERN_PTR _lantern_autograd_zero_grad) (void * self);
-HOST_API void lantern_autograd_zero_grad (void * self) {
+LANTERN_API void (LANTERN_PTR _lantern_autograd_zero_grad) (void * self, bool set_to_none);
+HOST_API void lantern_autograd_zero_grad (void * self, bool set_to_none) {
   LANTERN_CHECK_LOADED
-  _lantern_autograd_zero_grad(self);
+  _lantern_autograd_zero_grad(self, set_to_none);
   LANTERN_HOST_HANDLER;
 }
16 changes: 10 additions & 6 deletions src/lantern/src/Autograd.cpp
@@ -296,14 +296,18 @@ void *_lantern_Edge_function(void *self) {
   LANTERN_FUNCTION_END
 }
 
-void _lantern_autograd_zero_grad (void * self) {
+void _lantern_autograd_zero_grad (void * self, bool set_to_none) {
   LANTERN_FUNCTION_START
   auto list = from_raw::TensorList(self);
-  for (auto &t : list) {
-    auto grad = t.grad();
-    if (grad.defined()) {
-      grad.zero_();
-    }
+  for (auto &p : list) {
+    if (p.mutable_grad().defined()) {
+      p.mutable_grad().detach_();
+      if (set_to_none) {
+        p.mutable_grad().reset();
+      } else {
+        p.mutable_grad().zero_();
+      }
+    }
   }
   LANTERN_FUNCTION_END_VOID
 }
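
The reworked lantern routine above first detaches a parameter's existing gradient and then either releases it (p.mutable_grad().reset(), when set_to_none is true) or zeroes it in place, following the same pattern PyTorch's Optimizer.zero_grad() uses. Releasing the gradient drops the tensor entirely, and on the R side the parameter's $grad then comes back as an undefined tensor, which is what the new test below checks.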
12 changes: 12 additions & 0 deletions tests/testthat/test-optim-sgd.R
@@ -53,4 +53,16 @@ test_that("copy state between optimizers corecctly", {
   opt$zero_grad()
 
   expect_equal_to_tensor(x, y)
 })
+
+test_that("zero_grad set_to_none", {
+  # start with a tensor and make one step with the optimizer
+  x <- torch_tensor(1, requires_grad = TRUE)
+
+  opt <- optim_sgd(x, lr = 0.1)
+  (2*x)$backward()
+  opt$step()
+  opt$zero_grad(set_to_none = TRUE)
+
+  expect_true(is_undefined_tensor(x$grad))
+})
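
For completeness, a companion check for the default path could look like the sketch below; it is not part of this commit, and it assumes expect_equal_to_tensor compares values the same way it does in the existing test above:

test_that("zero_grad default still zeroes the gradient in place", {
  x <- torch_tensor(1, requires_grad = TRUE)

  opt <- optim_sgd(x, lr = 0.1)
  (2 * x)$backward()
  opt$step()
  opt$zero_grad()

  expect_equal_to_tensor(x$grad, torch_tensor(0))
})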
