Optimizer fixes (#1100)
* Adds cosine annealing and makes sure we correctly recover state dicts from learning rate schedulers.

* test

* use integer step size.
dfalbel authored Sep 11, 2023
1 parent cccf2ce commit 1610838
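
For orientation, a minimal usage sketch of the new scheduler, based on the signature and tests added in this commit; the model, data, and training loop below are illustrative only and not part of the change:

library(torch)

model <- nn_linear(10, 1)
optimizer <- optim_adam(model$parameters, lr = 0.1)
# T_max is the half-period of the cosine curve; eta_min is the lower bound
# the learning rate is annealed towards.
scheduler <- lr_cosine_annealing(optimizer, T_max = 10, eta_min = 1e-5)

x <- torch_randn(32, 10)
y <- torch_randn(32, 1)
for (epoch in 1:10) {
  optimizer$zero_grad()
  loss <- nnf_mse_loss(model(x), y)
  loss$backward()
  optimizer$step()
  scheduler$step()  # anneal the learning rate once per epoch
}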
Showing 6 changed files with 140 additions and 3 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
@@ -224,6 +224,7 @@ export(local_enable_grad)
export(local_no_grad)
export(local_torch_manual_seed)
export(loop)
export(lr_cosine_annealing)
export(lr_lambda)
export(lr_multiplicative)
export(lr_one_cycle)
51 changes: 49 additions & 2 deletions R/optim-lr_scheduler.R
@@ -45,11 +45,18 @@ LRScheduler <- R6::R6Class(

self$last_epoch <- last_epoch
self$verbose <- verbose

self$.step_count <- 0L
self$step()
},
state_dict = function() {
dict <- as.list(self)
- dict <- dict[[-which(names(dict) == "optimizer")]]
+ dict <- dict[-which(names(dict) == "optimizer")]

# we also drop functions and environments
dict <- dict[!sapply(dict, is.function)]
dict <- dict[!sapply(dict, is.environment)]

dict
},
load_state_dict = function(state_dict) {
@@ -69,11 +76,12 @@
}
},
step = function() {
self$.step_count <- self$.step_count + 1L
self$last_epoch <- self$last_epoch + 1
values <- self$get_lr()

for (i in seq_along(self$optimizer$param_groups)) {
- self$optimizer$param_groups[[i]]$lr <- values[i]
+ self$optimizer$param_groups[[i]]$lr <- values[[i]]
self$print_lr(self$verbose, i, self$optimizer$param_groups[[i]]$lr)
}

@@ -732,3 +740,42 @@ lr_reduce_on_plateau <- lr_scheduler(
}

)

#' Set the learning rate of each parameter group using a cosine annealing schedule
#'
#' @param T_max Maximum number of iterations.
#' @param eta_min Minimum learning rate. Default: 0.
#' @param last_epoch The index of the last epoch.
#'
#' @inheritParams lr_reduce_on_plateau
#' @export
lr_cosine_annealing <- lr_scheduler(
"lr_cosine_annealing",
initialize = function(optimizer, T_max, eta_min=0, last_epoch=-1, verbose=FALSE) {
self$T_max <- T_max
self$eta_min <- eta_min
super$initialize(optimizer, last_epoch, verbose)
},
get_lr = function() {
if (self$last_epoch == 0) {
return(lapply(self$optimizer$param_groups, function(x) x[["lr"]]))
} else if (self$.step_count == 1 && self$last_epoch > 0) {
lapply(self$base_lrs, function(base_lr) {
self$eta_min +
(base_lr - self$eta_min) *
(1 + cos(self$last_epoch * pi / self$T_max)) /
2
})
} else if ((self$last_epoch - 1 - self$T_max) %% (2 * self$T_max) == 0) {
map2(self$optimizer$param_groups, self$base_lrs, function(group, base_lr) {
group[["lr"]] + (base_lr - self$eta_min) * (1 - cos(pi / self$T_max)) / 2
})
} else {
lapply(self$optimizer$param_groups, function(group) {
(1 + cos(pi * self$last_epoch / self$T_max)) /
(1 + cos(pi * (self$last_epoch - 1) / self$T_max)) *
(group[['lr']] - self$eta_min) + self$eta_min
})
}
}
)
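
For intuition, the chained update in get_lr() is intended to follow the closed-form cosine schedule eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, as in PyTorch's CosineAnnealingLR. A short sketch comparing the scheduler against that formula; the closed_form() helper is only for illustration:

library(torch)

m <- nn_linear(10, 10)
o <- optim_adam(params = m$parameters, lr = 0.1)
sched <- lr_cosine_annealing(o, T_max = 5, eta_min = 1e-5)

# Closed-form value of the schedule at step t (illustrative helper).
closed_form <- function(t, eta_max = 0.1, eta_min = 1e-5, T_max = 5) {
  eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2
}

for (t in 1:5) {
  sched$step()
  cat(sprintf(
    "t = %d  scheduler: %.6f  closed form: %.6f\n",
    t, o$param_groups[[1]]$lr, closed_form(t)
  ))
}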
9 changes: 9 additions & 0 deletions R/optim.R
@@ -142,6 +142,15 @@ Optimizer <- R6::R6Class(
self$state$set(parameters[[index]], value)
}

# we must also update the param groups
for (i in seq_along(state_dict$param_groups)) {
group <- state_dict$param_groups[[i]]
for (nm in names(group)) {
if (nm == "params") next
self$param_groups[[i]][[nm]] <- group[[nm]]
}
}

invisible(self)
}
),
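
A minimal sketch of the behaviour this hunk enables (it mirrors the updated SGD test further down): param group settings such as the learning rate are now carried over when an optimizer state dict is loaded.

library(torch)

m <- nn_linear(10, 10)
o1 <- optim_sgd(params = m$parameters, lr = 0.1)

# The second optimizer starts with a different lr on purpose; loading the
# state dict should now also restore param group fields such as lr.
o2 <- optim_sgd(params = m$parameters, lr = 1)
o2$load_state_dict(o1$state_dict())
o2$param_groups[[1]]$lr  # expected: 0.1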
29 changes: 29 additions & 0 deletions man/lr_cosine_annealing.Rd


49 changes: 49 additions & 0 deletions tests/testthat/test-optim-lr_scheduler.R
@@ -119,3 +119,52 @@ test_that("lr_reduce_on_plateau", {
}
expect_equal(o$param_groups[[1]]$lr, 0.09) # matched to pytorch
})

test_that("lr_cosine_annealing", {

m <- nn_linear(10, 10)
o <- optim_adam(params = m$parameters, lr = 0.1)
scheduler <- lr_cosine_annealing(o, T_max = 1, eta_min = 1e-5)

expect_equal(o$param_groups[[1]]$lr, 0.1)


scheduler$step()
expect_equal(o$param_groups[[1]]$lr, 1e-5)
scheduler$step()
expect_equal(o$param_groups[[1]]$lr, 0.1)
scheduler$step()
expect_equal(o$param_groups[[1]]$lr, 1e-5)

})

test_that("state dict works", {

m <- nn_linear(10, 10)
o <- optim_sgd(params = m$parameters, lr = 1)
scheduler <- lr_step(optimizer = o, step_size = 1)

expect_equal(o$param_groups[[1]]$lr, 1)
scheduler$step()
expect_equal(o$param_groups[[1]]$lr, 0.1)

dict <- scheduler$state_dict()
opt_dict <- o$state_dict()

scheduler$step()
expect_equal(o$param_groups[[1]]$lr, 0.01)

o <- optim_sgd(params = m$parameters, lr = 1)
expect_equal(o$param_groups[[1]]$lr, 1)

o$load_state_dict(opt_dict)
expect_equal(o$param_groups[[1]]$lr, 0.1)

scheduler <- lr_step(optimizer = o, step_size = 2) # use a different value
scheduler$load_state_dict(dict)

scheduler$step()

expect_equal(o$param_groups[[1]]$lr, 0.01)

})
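
In practice the same round trip supports pausing and resuming training. A minimal in-memory sketch using only the calls exercised in the test above (serializing the checkpoint to disk is left out):

library(torch)

m <- nn_linear(10, 10)
o <- optim_sgd(params = m$parameters, lr = 1)
sched <- lr_step(optimizer = o, step_size = 1)
sched$step()  # lr is now 0.1

# Checkpoint both state dicts (kept in memory here).
ckpt <- list(optim = o$state_dict(), sched = sched$state_dict())

# Later: rebuild the objects, restore their state, and continue training.
o2 <- optim_sgd(params = m$parameters, lr = 1)
o2$load_state_dict(ckpt$optim)
sched2 <- lr_step(optimizer = o2, step_size = 1)
sched2$load_state_dict(ckpt$sched)
sched2$step()
o2$param_groups[[1]]$lr  # continues the schedule where it left off: 0.01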
4 changes: 3 additions & 1 deletion tests/testthat/test-optim-sgd.R
@@ -39,8 +39,10 @@ test_that("copy state between optimizers corecctly", {
with_no_grad({
y <- torch_empty(1, requires_grad = TRUE)$copy_(x)
})
- opt2 <- optim_adam(y, lr = 0.1)
+ opt2 <- optim_adam(y, lr = 1) # use a different LR to make sure it is recovered
opt2$load_state_dict(opt$state_dict())
expect_equal(opt2$param_groups[[1]]$lr, 0.1)

(2*y)$backward()
opt2$step()
opt2$state_dict()
