Implement luz_callback_validation_check #56
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#' @include callbacks.R | ||
NULL | ||
|
||
#' Validation Check | ||
#' | ||
#' Check validation loop before fitting model. | ||
#' | ||
#' @param batches Number of validation batches to check. Default is 2. | ||
#' | ||
#' @note Usually the training loop is much longer than the validation | ||
#' loop and issues with the validation loop aren't encountered until after | ||
#' a long training runtime. This callback runs the validation loop first on | ||
#' `batches` number of batches and then proceeds onto the standard | ||
#' training process. | ||
#' | ||
#' @note Printing can be disabled by passing `verbose = FALSE` to | ||
#' [fit.luz_module_generator()]. | ||
#' | ||
#' @family luz_callbacks | ||
#' | ||
#' @returns | ||
#' A `luz_callback`. | ||
#' | ||
#' @export | ||
luz_callback_validation_check <- luz_callback( | ||
"validation_check_callback", | ||
initialize = function(batches = 2) { | ||
if (!rlang::is_scalar_integerish(batches)) { | ||
rlang::abort("`batches` must be a single integer value.") | ||
} | ||
self$batches <- batches | ||
}, | ||
on_fit_begin = function() { | ||
if (is.null(ctx$valid_data)) return() | ||
if (self$batches <= 0) return() | ||
|
||
ctx$model$eval() | ||
ctx$training <- FALSE | ||
|
||
self$initialize_progress_bar() | ||
|
||
i <- 0 | ||
torch::with_no_grad({ | ||
coro::loop(for (batch in ctx$valid_data) { | ||
self$validate_one_batch(batch) | ||
self$tick_progress_bar(self$loss) | ||
i <- i + 1 | ||
if (i >= self$batches) break() | ||
}) | ||
}) | ||
}, | ||
validate_one_batch = function(batch) { | ||
input <- list(batch[[1]]) | ||
target <- batch[[2]] | ||
pred <- do.call(ctx$model, input) | ||
self$loss <- ctx$model$loss(pred, target) | ||
Review comment: I think in general we would want to do the full validation step, because the errors could be in any of the callbacks etc. But we would need to take care of the side effects that this might cause. We would need to call [...]
Review comment: I was thinking about this problem. It feels to me that the safest way would be to [...] I think this is possible if the first thing we do in the [...]
Review comment: Yeah, I've kind of gone in circles here. We want to call the validation callbacks so that it is a complete check of the validation loop, but I was worried about any changes in state this might cause. I did consider using [...]
Review comment: Somewhat related question, when [...]
Review comment: Yes, they are called in that order: default callbacks, then user callbacks.
Review comment: I did actually think about calling [...]
Review comment: Thinking again, there could still be some side effects, e.g. the callbacks passed by the user can have side effects outside of the R session (maybe writing to a file or something like this). So maybe we want to call fit again, with only the default callbacks plus the one that breaks the training loop. This is not completely ideal, because there would still be callbacks that could fail in the 'real' pass. But this sounds like enough, I guess.
Review comment: I agree with only calling the default callbacks. My original reason for avoiding callbacks was loggers and other things written to disk. But if we only run default callbacks, we can avoid this issue. The function docs can just point out that user callbacks aren't validated. |
||
}, | ||
initialize_progress_bar = function() { | ||
format <- "Validation check: :current/:total [:bar] - Loss: :loss" | ||
self$pb <- progress::progress_bar$new( | ||
force = getOption("luz.force_progress_bar", FALSE), | ||
show_after = 0, | ||
format = format, | ||
total = self$batches, | ||
clear = FALSE | ||
) | ||
}, | ||
tick_progress_bar = function(token) { | ||
if (ctx$verbose) { | ||
loss <- format(round(as.numeric(token), digits = 4), nsmall = 4) | ||
self$pb$tick(tokens = list(loss = loss)) | ||
} | ||
} | ||
) |
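The approach the review thread settles on (run a short check pass with only the default callbacks, stop after a few batches, and leave user callbacks untouched) can be sketched in base R. This is a toy illustration under stated assumptions, not luz code: `run_validation_check`, `step`, `default_callbacks`, and `n_check` are hypothetical names, and plain functions stand in for luz callbacks.

```r
# Hypothetical stand-ins; none of these names are part of the luz API.
run_validation_check <- function(batches, step, default_callbacks, n_check = 2) {
  seen <- 0
  for (batch in batches) {
    result <- step(batch)
    # Only default callbacks run here, so user callbacks with external
    # side effects (e.g. file loggers) are never triggered by the check.
    for (cb in default_callbacks) cb(result)
    seen <- seen + 1
    if (seen >= n_check) break
  }
  invisible(seen)
}

# A user callback with a side effect; it is deliberately not passed to the check.
log_file <- character()
user_logger <- function(result) log_file <<- c(log_file, result)

# A default callback that only updates in-session state.
metrics <- numeric()
default_metric <- function(result) metrics <<- c(metrics, result)

n_seen <- run_validation_check(
  batches = list(1, 2, 3, 4),
  step = function(batch) batch * 10,
  default_callbacks = list(default_metric)
)
```

As the thread notes, the trade-off is that user callbacks are not exercised by the check, so a failure in one of them would still only surface during the real pass.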
We might want to extract this out into a function:
luz/R/callbacks.R
Lines 298 to 300 in 6e0bb77
And reuse here so we make sure that the same changes are always set?
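
A minimal sketch of what such an extraction could look like. The contents of lines 298 to 300 in `luz/R/callbacks.R` are not shown on this page, so this assumes they perform the same two steps as `on_fit_begin` in the diff (`ctx$model$eval()` and `ctx$training <- FALSE`). The helper name `set_validation_mode` and the mock context are hypothetical.

```r
# Hypothetical helper; it mirrors the two lines in the diff, not the
# actual contents of luz/R/callbacks.R.
set_validation_mode <- function(ctx) {
  ctx$model$eval()
  ctx$training <- FALSE
  invisible(ctx)
}

# Minimal mock context built from environments, so that changes made
# inside the helper persist by reference (as they would on a real ctx).
mock_model <- new.env()
mock_model$mode <- "train"
mock_model$eval <- function() assign("mode", "eval", envir = mock_model)

mock_ctx <- new.env()
mock_ctx$training <- TRUE
mock_ctx$model <- mock_model

set_validation_mode(mock_ctx)
```

With a helper like this, both the regular validation step and the validation check would call the same function, so a later change to how validation mode is entered only needs to be made once.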