Implement luz_callback_validation_check #56

Draft · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -47,6 +47,7 @@ Collate:
'accelerator.R'
'utils.R'
'callbacks.R'
'callback-validation-check.R'
'callbacks-interrupt.R'
'callbacks-profile.R'
'context.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -19,6 +19,7 @@ export(luz_callback_model_checkpoint)
export(luz_callback_profile)
export(luz_callback_progress)
export(luz_callback_train_valid)
export(luz_callback_validation_check)
export(luz_load)
export(luz_load_model_weights)
export(luz_metric)
74 changes: 74 additions & 0 deletions R/callback-validation-check.R
@@ -0,0 +1,74 @@
#' @include callbacks.R
NULL

#' Validation Check
#'
#' Checks the validation loop before fitting the model.
#'
#' @param batches Number of validation batches to check. Default is 2.
#'
#' @note The training loop is usually much longer than the validation
#' loop, so issues in the validation loop may not surface until after a
#' long training runtime. This callback runs the validation loop on the
#' first `batches` batches before training begins, then proceeds with the
#' standard training process.
#'
#' @note Printing can be disabled by passing `verbose = FALSE` to
#' [fit.luz_module_generator()].
#'
#' @family luz_callbacks
#'
#' @returns
#' A `luz_callback`.
#'
#' @export
luz_callback_validation_check <- luz_callback(
"validation_check_callback",
initialize = function(batches = 2) {
if (!rlang::is_scalar_integerish(batches)) {
rlang::abort("`batches` must be a single integer value.")
}
self$batches <- batches
},
on_fit_begin = function() {
if (is.null(ctx$valid_data)) return()
if (self$batches <= 0) return()

ctx$model$eval()
Member:

We might want to extract this out into a function:

From luz/R/callbacks.R, lines 298 to 300 at 6e0bb77:

ctx$model$eval()
ctx$training <- FALSE
ctx$loss <- list()

And reuse it here, so we make sure the same state changes are always applied?
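
A minimal sketch of that extraction (hypothetical; `ctx_prepare_validation()` is an assumed name, not part of luz):

# Hypothetical helper: centralizes the state changes needed before a
# validation pass, so this callback and R/callbacks.R stay in sync.
ctx_prepare_validation <- function(ctx) {
  ctx$model$eval()       # put the torch module in eval mode
  ctx$training <- FALSE  # mark the context as not training
  ctx$loss <- list()     # reset accumulated losses
}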

    ctx$training <- FALSE

    self$initialize_progress_bar()

    i <- 0
    torch::with_no_grad({
      coro::loop(for (batch in ctx$valid_data) {
        self$validate_one_batch(batch)
        self$tick_progress_bar(self$loss)
        i <- i + 1
        if (i >= self$batches) break # only check the first `batches` batches
      })
    })
  },
  validate_one_batch = function(batch) {
    input <- list(batch[[1]]) # wrap the single input; multi-input models are not handled here
    target <- batch[[2]]
    pred <- do.call(ctx$model, input)
    self$loss <- ctx$model$loss(pred, target)
Member:

I think in general we would want to do the full validation step, because the errors could be in any of the callbacks, etc. But we would need to take care of the side effects that this might cause.

We would need to call valid_one_step() and then make sure we can reset the state. I'm not sure yet what the best way to do that would be.
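
For illustration, a hedged sketch of one possible reset strategy (the snapshotted fields and the `valid_one_step()` call are assumptions about luz internals):

# Hypothetical sketch: snapshot the mutable ctx fields, run the full
# validation step (callbacks included), then restore the snapshot so the
# check leaves no trace in the real fitting state.
snapshot <- list(loss = ctx$loss, metrics = ctx$metrics, pred = ctx$pred)
valid_one_step(batch)  # assumed internal: runs the full validation step
ctx$loss <- snapshot$loss
ctx$metrics <- snapshot$metrics
ctx$pred <- snapshot$pred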

Member:

I was thinking about this problem. It feels to me that the safest way would be, in on_fit_begin(), to call fit again and add a callback that breaks the training loop after `batches` steps, for both training and validation. This way no side effects would interfere with the actual training loop, but we would still run the full loop, which would detect the other possible bugs.

I think this is possible if the first thing we do in the ctx object is save a list with all the arguments that were passed to fit, before we do any kind of manipulation (like we do for callbacks). To avoid infinite recursion we could check ctx$callbacks for the presence of the callback that breaks the loop.
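
A minimal sketch of such a loop-breaking callback (hypothetical: the condition class is made up, and fit() would need a handler, perhaps via ctx$handlers, that catches it and exits the loop cleanly):

# Hypothetical sketch, not part of this PR: counts batches in each phase
# and signals a classed condition after `batches` steps.
luz_callback_break_loop <- luz_callback(
  "break_loop_callback",
  initialize = function(batches = 2) {
    self$batches <- batches
  },
  on_train_begin = function() self$seen <- 0,
  on_valid_begin = function() self$seen <- 0,
  on_train_batch_end = function() {
    self$seen <- self$seen + 1
    if (self$seen >= self$batches) rlang::abort("break", class = "luz_break_loop")
  },
  on_valid_batch_end = function() {
    self$seen <- self$seen + 1
    if (self$seen >= self$batches) rlang::abort("break", class = "luz_break_loop")
  }
)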

Contributor Author:

Yeah, I've kind of gone in circles here. We want to call the validation callbacks so it's a complete check of the validation loop, but I was worried about any changes in state this might cause. I did consider using valid_one_batch(), but at the time decided against it for the reasons above.

Contributor Author (@mattwarkentin, Jul 30, 2021):

Somewhat related question: when ctx$call_callbacks("on_..._...") is called and multiple callbacks have methods for that breakpoint, in what order are they called? Default callbacks first, user-supplied callbacks second?

Member:

Yes, they are called in that order: default callbacks, then user callbacks.

I think that if we call fit again there would be no interference; the only difference is that it would also test the training loop. But we could also skip that anyway...

Contributor Author:

I did actually think about calling fit() again inside on_fit_begin(), but decided against it. You're right, though: it would be a good way to check both the training and validation loops before committing to a full fit.

Member (@dfalbel, Jul 30, 2021):

Thinking again, there could still be some side effects, e.g. callbacks passed by the user can have side effects outside of the R session (maybe writing to a file or something like that). So maybe we want to call fit again with only the default callbacks plus the one that breaks the training loop.

This is not completely ideal, because there would still be callbacks that could fail in the 'real' pass. But it sounds like enough, I guess.

Contributor Author:

I agree with only calling the default callbacks. My original reason for avoiding callbacks was loggers and other things being written to disk. But if we only run the default callbacks, we avoid this issue. The function docs can just point out that user callbacks aren't validated.
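
For illustration, a hedged sketch of that final shape (`ctx$fit_args` as saved fit arguments and `default_callbacks()` are assumptions, and `luz_callback_break_loop()` is the sketch from the earlier comment):

# Hypothetical on_fit_begin(): re-run fit() with only luz's default
# callbacks plus the loop breaker, so user-supplied callbacks (loggers,
# checkpointers, ...) cause no side effects during the check.
on_fit_begin = function() {
  # recursion guard: skip if the loop breaker is already installed
  if (any(sapply(ctx$callbacks, inherits, "break_loop_callback"))) return()
  args <- ctx$fit_args                     # assumed: args saved at the top of fit()
  args$callbacks <- c(default_callbacks(), # assumed helper returning luz's defaults
                      list(luz_callback_break_loop(self$batches)))
  do.call(fit, c(list(ctx$module), args))  # dry-run both loops on a few batches
}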

  },
  initialize_progress_bar = function() {
    format <- "Validation check: :current/:total [:bar] - Loss: :loss"
    self$pb <- progress::progress_bar$new(
      force = getOption("luz.force_progress_bar", FALSE),
      show_after = 0,
      format = format,
      total = self$batches,
      clear = FALSE
    )
  },
  tick_progress_bar = function(token) {
    if (ctx$verbose) {
      loss <- format(round(as.numeric(token), digits = 4), nsmall = 4)
      self$pb$tick(tokens = list(loss = loss))
    }
  }
)
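
For context, a usage sketch of the new callback (illustrative only: `net`, `train_dl`, and `valid_dl` are assumed to be a luz module and torch dataloaders):

# Illustrative usage: check 2 validation batches before training starts.
library(luz)
library(magrittr) # for %>%

fitted <- net %>%
  setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam
  ) %>%
  fit(
    train_dl,
    epochs = 10,
    valid_data = valid_dl,
    callbacks = list(luz_callback_validation_check(batches = 2))
  )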
3 changes: 2 additions & 1 deletion R/module.R
@@ -413,8 +413,9 @@ clean_context <- function(ctx) {
"pred",
"opt",
"opt_name",
"data",
"handlers",
"data",
"train_data",
"valid_data",
"loss",
"input",
The following generated .Rd files are not rendered by default on GitHub:

3 changes: 2 additions & 1 deletion man/ctx.Rd
3 changes: 2 additions & 1 deletion man/luz_callback.Rd
1 change: 1 addition & 0 deletions man/luz_callback_csv_logger.Rd
1 change: 1 addition & 0 deletions man/luz_callback_early_stopping.Rd
1 change: 1 addition & 0 deletions man/luz_callback_interrupt.Rd
1 change: 1 addition & 0 deletions man/luz_callback_lr_scheduler.Rd
1 change: 1 addition & 0 deletions man/luz_callback_metrics.Rd
1 change: 1 addition & 0 deletions man/luz_callback_model_checkpoint.Rd
1 change: 1 addition & 0 deletions man/luz_callback_profile.Rd
1 change: 1 addition & 0 deletions man/luz_callback_progress.Rd
1 change: 1 addition & 0 deletions man/luz_callback_train_valid.Rd
41 changes: 41 additions & 0 deletions man/luz_callback_validation_check.Rd