Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial work on tracing callback #57

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions R/callbacks-trace.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

#' Traces the model for training and validation to get a few speedups.
#'
#' On the first batch of the first epoch, [torch::jit_trace_module()] is used
#' to create a JIT-traced copy of the model; the traced copy then replaces the
#' original model for the remaining training and validation steps. When
#' training ends, the trained weights are copied back into the original
#' module.
#'
#' @param check_train If `TRUE` (default), the traced model's output is
#'   compared against the original model's output on the tracing batch and an
#'   error is raised on mismatch. Disable for models that are
#'   non-deterministic during training (e.g. containing dropout).
#' @param check_valid Same check, applied to the validation trace.
#'
#' @export
luz_callback_trace <- luz_callback(
  "trace_callback",
  initialize = function(check_train = TRUE, check_valid = TRUE) {
    # one traced module per phase; filled lazily on the first batch
    self$traced_models <- list(train = NULL, valid = NULL)
    self$check_train <- check_train
    self$check_valid <- check_valid
  },
  on_train_batch_after_pred = function() {

    # tracing happens exactly once: first iteration of the first epoch
    if (ctx$epoch > 1 || ctx$iter > 1) return()

    # Traces the model for training
    if (is.null(self$traced_models$train)) {
      traced <- torch::jit_trace_module(
        ctx$model,
        forward = ctx$input,
        loss = list(ctx$pred, ctx$target)
      )
    } else {
      # this hook runs once per optimizer in the same iteration; reuse the
      # trace produced on the first pass
      traced <- self$traced_models$train
      rlang::warn("The same traced model is going to be used for both optimizers and this might lead to unexpected results.")
    }

    if (self$check_train) {

      # `all.equal()` may return a character description rather than FALSE,
      # so guard with isTRUE() instead of negating a possibly non-logical
      # value.
      equal <- isTRUE(all.equal(do.call(traced, list(ctx$input)), ctx$pred))
      if (!equal) {
        # pass a single message vector; a second positional argument to
        # abort() would be interpreted as the condition class, not text
        rlang::abort(c(
          "Traced model didn't produce identical results when compared to the original model during training.",
          i = "Are you sure that tracing is producing the correct output? If yes, disable this error with `check_train=FALSE`."
        ))
      }
    }

    self$traced_models$train <- traced
    ctx$model <- self$traced_models$train
  },
  on_train_begin = function() {
    # keep a handle on the original module so it can be restored later
    self$model <- ctx$model
    if (ctx$epoch <= 1) return()
    ctx$model <- self$traced_models$train
  },
  on_train_end = function() {
    # restore the original module and copy the trained weights back into it
    ctx$model <- self$model
    ctx$model$load_state_dict(self$traced_models$train$state_dict())
  },
  on_valid_batch_after_pred = function() {
    if (ctx$epoch > 1 || ctx$iter > 1) return()

    # Traces the model for validation
    if (is.null(self$traced_models$valid)) {
      traced <- torch::jit_trace_module(
        ctx$model,
        forward = ctx$input,
        loss = list(ctx$pred, ctx$target)
      )
    } else {
      traced <- self$traced_models$valid
      rlang::warn("The same traced model is going to be used for both optimizers and this might lead to unexpected results.")
    }

    if (self$check_valid) {
      equal <- isTRUE(all.equal(do.call(traced, list(ctx$input)), ctx$pred))
      if (!equal) {
        rlang::abort(c(
          "Traced model didn't produce identical results when compared to the original model during validation",
          i = "Are you sure that tracing is producing the correct output? If yes, disable this error with `check_valid=FALSE`."
        ))
      }
    }

    self$traced_models$valid <- traced
    ctx$model <- self$traced_models$valid
  },
  on_valid_begin = function() {
    self$model <- ctx$model
    if (ctx$epoch <= 1) return()
    ctx$model <- self$traced_models$valid
  },
  on_valid_end = function() {
    ctx$model <- self$model
  }
)

# S3 method so that `all.equal()` on torch tensors (as used by the tracing
# callback's check_train/check_valid comparisons) compares values numerically
# via `torch_allclose()` instead of comparing R object attributes. Extra
# arguments (e.g. tolerances) are forwarded to `torch::torch_allclose()`.
# NOTE(review): unlike base `all.equal()`, this returns FALSE on mismatch
# rather than a character description — callers should still use `isTRUE()`
# for safety.
all.equal.torch_tensor <- function(target, current, ...) {
  torch::torch_allclose(target, current, ...)
}
6 changes: 3 additions & 3 deletions R/module.R
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,12 @@ default_step <- function(ctx) {

fit_one_batch <-function(ctx) {
for (nm in names(ctx$optimizers)) {
ctx$pred <- do.call(ctx$model, list(ctx$input))
ctx$call_callbacks("on_train_batch_after_pred")

ctx$opt <- ctx$optimizers[[nm]]
ctx$opt_name <- nm

ctx$pred <- do.call(ctx$model, list(ctx$input))
ctx$call_callbacks("on_train_batch_after_pred")

ctx$loss_grad <- ctx$model$loss(ctx$pred, ctx$target)
ctx$loss[[ctx$opt_name]] <- ctx$loss_grad$detach()

Expand Down
101 changes: 101 additions & 0 deletions tests/testthat/test-callbacks-trace.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
test_that("trace callback for minimum model", {

  # a plain linear module is deterministic, so the tracing checks should pass
  spec <- torch::nn_linear %>%
    setup(loss = torch::nn_mse_loss(), optimizer = torch::optim_adam) %>%
    set_hparams(in_features = 10, out_features = 10)

  data <- torch::tensor_dataset(
    x = torch::torch_randn(10, 10),
    y = torch::torch_randn(10, 10)
  )
  loader <- torch::dataloader(data, batch_size = 2)

  fitted <- fit(spec, loader, epochs = 5,
                callbacks = list(luz_callback_trace()), verbose = FALSE)

  expect_s3_class(fitted, "luz_module_fitted")
})

test_that("works by disabling the checks", {
  # This model includes a dropout layer, so its training forward pass is
  # non-deterministic and the traced-output comparison should fail when
  # check_train = TRUE.

  net <- torch::nn_module(
    initialize = function() {
      self$linear <- torch::nn_linear(10, 10)
      self$dropout <- torch::nn_dropout()
    },
    forward = function(x) {
      self$dropout(self$linear(x))
    }
  )

  spec <- setup(net,
                loss = torch::nn_mse_loss(),
                optimizer = torch::optim_adam)

  data <- torch::tensor_dataset(
    x = torch::torch_randn(10, 10),
    y = torch::torch_randn(10, 10)
  )
  loader <- torch::dataloader(data, batch_size = 2)

  # the default check must catch the mismatch
  expect_error(
    fitted <- fit(spec, loader, epochs = 5, valid_data = loader,
                  callbacks = list(luz_callback_trace()), verbose = FALSE),
    regexp = "Traced model didn't"
  )

  # disabling the training check lets fitting proceed
  fitted <- fit(spec, loader, epochs = 5, valid_data = loader,
                callbacks = list(luz_callback_trace(check_train = FALSE)),
                verbose = FALSE)
  expect_s3_class(fitted, "luz_module_fitted")
})

test_that("works by disabling validation checks too", {

  # adding random noise in forward() makes this model non-deterministic in
  # both training and validation, so both checks must be disabled for the
  # fit to succeed
  net <- torch::nn_module(
    initialize = function() {
      self$linear <- torch::nn_linear(10, 10)
    },
    forward = function(x) {
      self$linear(x) + torch::torch_randn(1)
    }
  )

  spec <- setup(net,
                loss = torch::nn_mse_loss(),
                optimizer = torch::optim_adam)

  data <- torch::tensor_dataset(
    x = torch::torch_randn(10, 10),
    y = torch::torch_randn(10, 10)
  )
  loader <- torch::dataloader(data, batch_size = 2)

  # fails on the training check
  expect_error(
    fitted <- fit(spec, loader, epochs = 5, valid_data = loader,
                  callbacks = list(luz_callback_trace()), verbose = FALSE),
    regexp = "Traced model didn't"
  )

  # still fails on the validation check
  expect_error(
    fitted <- fit(spec, loader, epochs = 5, valid_data = loader,
                  callbacks = list(luz_callback_trace(check_train = FALSE)),
                  verbose = FALSE),
    regexp = "Traced model didn't"
  )

  # succeeds once both checks are off
  fitted <- fit(
    spec, loader, epochs = 5, valid_data = loader,
    callbacks = list(luz_callback_trace(check_train = FALSE, check_valid = FALSE)),
    verbose = FALSE
  )
  expect_s3_class(fitted, "luz_module_fitted")
})

test_that("parameters are correctly updated", {

  # fit a linear model on y = x1 + x2 and check the final metric is small,
  # i.e. optimizer updates actually reach the (traced) model's parameters
  spec <- torch::nn_linear %>%
    setup(loss = torch::nn_mse_loss(), optimizer = torch::optim_sgd) %>%
    set_hparams(in_features = 2, out_features = 1) %>%
    set_opt_hparams(lr = 0.1)

  x <- torch::torch_randn(10, 2)
  y <- torch::torch_sum(x, 2, keepdim = TRUE)

  loader <- torch::dataloader(
    torch::tensor_dataset(x = x, y = y),
    batch_size = 5
  )

  fitted <- fit(spec, loader, epochs = 20, verbose = FALSE,
                callbacks = luz_callback_trace())
  expect_lt(tail(fitted$ctx$get_metrics_df()$value, 1), 0.1)

})