From 3f027afe2f9882f7606304dd4bf24c58ea3a8034 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 26 Oct 2024 08:42:51 +0200 Subject: [PATCH 1/2] remove old note and unused class --- DESCRIPTION | 2 +- R/script_module.R | 75 ----------------------------------------------- R/trace.R | 1 - 3 files changed, 1 insertion(+), 77 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index ecab8f937c..52d59064b9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -44,7 +44,7 @@ Imports: desc, safetensors (>= 0.1.1), jsonlite -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Roxygen: list(markdown = TRUE) Suggests: testthat (>= 3.0.0), diff --git a/R/script_module.R b/R/script_module.R index 3a954a9459..0f64c4295a 100644 --- a/R/script_module.R +++ b/R/script_module.R @@ -1,78 +1,3 @@ -ScriptModule <- R7Class( - "torch_script_module", - public = list( - ptr = NULL, - initialize = function(ptr) { - ptr - }, - train = function(mode = TRUE) { - cpp_jit_script_module_train(self, mode) - invisible(self) - }, - register_parameter = function(name, param) { - cpp_jit_script_module_register_parameter(self, name, param, FALSE) - invisible(self) - }, - register_buffer = function(name, tensor, persistent = TRUE) { - if (!persistent) { - runtime_error("ScriptModule does not support non persistent buffers.") - } - cpp_jit_script_module_register_buffer(self, name, tensor) - invisible(self) - }, - register_module = function(name, module) { - if (inherits(module, "script_module")) { - module <- module$..ptr..() - } - - if (!inherits(module, "torch_script_module")) { - runtime_error("Script modules can only register Script modules children.") - } - - if (is.numeric(name)) { - name <- as.character(name) - } - - cpp_jit_script_module_register_module(self, name, module) - invisible(self) - }, - add_constant = function(name, value) { - cpp_jit_script_module_add_constant(self, name, value) - invisible(self) - }, - to = function(device, non_blocking = FALSE) { - cpp_jit_script_module_to(self, device, non_blocking) - invisible(self) - }, - find_method = function(name) { - cpp_jit_script_module_find_method(self, name) - }, - find_constant = function(name) { - cpp_jit_script_module_find_constant(self, name) - }, - save = function(path) { - cpp_jit_script_module_save(self, path) - }, - save_for_mobile = function(path) { - cpp_jit_script_module_save_for_mobile(self, path) - } - ), - active = list( - parameters = function() { - cpp_jit_script_module_parameters(self, TRUE) - }, - is_training = function() { - cpp_jit_script_module_is_training(self) - }, - buffers = function() { - cpp_jit_script_module_buffers(self, TRUE) - }, - modules = function() { - cpp_jit_script_module_children(self) - } - ) -) - nn_ScriptModule <- R6::R6Class( inherit = nn_Module, lock_objects = FALSE, diff --git a/R/trace.R b/R/trace.R index fc435afbb5..169e6b22f8 100644 --- a/R/trace.R +++ b/R/trace.R @@ -5,7 +5,6 @@ #' recording the operations performed on all the tensors. #' #' The resulting recording of a standalone function produces a `script_function`. -#' In the future we will also support tracing `nn_modules`. #' #' @section Warning: #' From ba47c27632c09dddfcdf858170a62deb613823e0 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 26 Oct 2024 08:51:37 +0200 Subject: [PATCH 2/2] re-add class --- R/script_module.R | 75 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/R/script_module.R b/R/script_module.R index 0f64c4295a..3a954a9459 100644 --- a/R/script_module.R +++ b/R/script_module.R @@ -1,3 +1,78 @@ +ScriptModule <- R7Class( + "torch_script_module", + public = list( + ptr = NULL, + initialize = function(ptr) { + ptr + }, + train = function(mode = TRUE) { + cpp_jit_script_module_train(self, mode) + invisible(self) + }, + register_parameter = function(name, param) { + cpp_jit_script_module_register_parameter(self, name, param, FALSE) + invisible(self) + }, + register_buffer = function(name, tensor, persistent = TRUE) { + if (!persistent) { + runtime_error("ScriptModule does not support non persistent buffers.") + } + cpp_jit_script_module_register_buffer(self, name, tensor) + invisible(self) + }, + register_module = function(name, module) { + if (inherits(module, "script_module")) { + module <- module$..ptr..() + } + + if (!inherits(module, "torch_script_module")) { + runtime_error("Script modules can only register Script modules children.") + } + + if (is.numeric(name)) { + name <- as.character(name) + } + + cpp_jit_script_module_register_module(self, name, module) + invisible(self) + }, + add_constant = function(name, value) { + cpp_jit_script_module_add_constant(self, name, value) + invisible(self) + }, + to = function(device, non_blocking = FALSE) { + cpp_jit_script_module_to(self, device, non_blocking) + invisible(self) + }, + find_method = function(name) { + cpp_jit_script_module_find_method(self, name) + }, + find_constant = function(name) { + cpp_jit_script_module_find_constant(self, name) + }, + save = function(path) { + cpp_jit_script_module_save(self, path) + }, + save_for_mobile = function(path) { + cpp_jit_script_module_save_for_mobile(self, path) + } + ), + active = list( + parameters = function() { + cpp_jit_script_module_parameters(self, TRUE) + }, + is_training = function() { + cpp_jit_script_module_is_training(self) + }, + buffers = function() { + cpp_jit_script_module_buffers(self, TRUE) + }, + modules = function() { + cpp_jit_script_module_children(self) + } + ) +) + nn_ScriptModule <- R6::R6Class( inherit = nn_Module, lock_objects = FALSE,