Skip to content

Commit

Permalink
feat(jit): (un)-serialization of jit script modules (#1240)
Browse files Browse the repository at this point in the history
* feat(jit): (un)-serialization of jit scropt modules

This is useful for sending jitted modules between R processes.

Resolves Issue #1236

* cleanup, docs

* trigger ci

* export jit serialization functions
  • Loading branch information
sebffischer authored Jan 21, 2025
1 parent 8e13333 commit ecdf13b
Show file tree
Hide file tree
Showing 14 changed files with 221 additions and 12 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ export(jit_ops)
export(jit_save)
export(jit_save_for_mobile)
export(jit_scalar)
export(jit_serialize)
export(jit_trace)
export(jit_trace_module)
export(jit_tuple)
export(jit_unserialize)
export(linalg_cholesky)
export(linalg_cholesky_ex)
export(linalg_cond)
Expand Down
8 changes: 8 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -15353,6 +15353,14 @@ cpp_jit_script_module_save <- function(self, path) {
invisible(.Call(`_torch_cpp_jit_script_module_save`, self, path))
}

cpp_jit_script_module_serialize <- function(self) {
.Call(`_torch_cpp_jit_script_module_serialize`, self)
}

cpp_jit_script_module_unserialize <- function(input) {
.Call(`_torch_cpp_jit_script_module_unserialize`, input)
}

cpp_jit_script_module_save_for_mobile <- function(self, path) {
invisible(.Call(`_torch_cpp_jit_script_module_save_for_mobile`, self, path))
}
Expand Down
5 changes: 3 additions & 2 deletions R/save.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ torch_load <- function(path, device = "cpu") {
if (is_rds(path)) {
return(legacy_torch_load(path, device))
}

if (is.null(device)) {
cli::cli_abort("Unexpected device {.val NULL}")
}
Expand Down Expand Up @@ -333,6 +333,7 @@ legacy_torch_load <- function(path, device = "cpu") {

#' Serialize a torch object returning a raw object
#'

#' It's just a wraper around [torch_save()].
#'
#' @inheritParams torch_save
Expand All @@ -341,7 +342,7 @@ legacy_torch_load <- function(path, device = "cpu") {
#' @returns A raw vector containing the serialized object. Can be reloaded using
#' [torch_load()].
#' @family torch_save
#'
#'
#' @export
#' @concept serialization
torch_serialize <- function(obj, ...) {
Expand Down
3 changes: 3 additions & 0 deletions R/script_module.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ ScriptModule <- R7Class(
save = function(path) {
cpp_jit_script_module_save(self, path)
},
serialize = function(path) {
cpp_jit_script_module_serialize(self)
},
save_for_mobile = function(path) {
cpp_jit_script_module_save_for_mobile(self, path)
}
Expand Down
36 changes: 36 additions & 0 deletions R/trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,42 @@ jit_save <- function(obj, path, ...) {
invisible(obj)
}

#' @title Serialize a Script Module
#' @description
#' Serializes a script module and returns it as a raw vector.
#' You can read the object again using [`jit_unserialize`].
#' @param obj (`script_module`)\cr
#' Model to be serialized.
#' @return `raw()`
#' @examples
#' model <- jit_trace(nn_linear(1, 1), torch_randn(1))
#' serialized <- jit_serialize(model)
#' @export
jit_serialize <- function(obj) {
if (inherits(obj, "script_module")) {
obj$..ptr..()$serialize()
} else {
value_error("Only `script_module` can be serialized with `jit_serialize`.")
}
}

#' @title Unserialize a Script Module
#' @description
#' Unserializes a script module from a raw vector (generated with [`jit_seriaize`]`).
#' @param obj (`raw`)\cr
#' Serialized model.
#' @return `script_module`
#' model <- jit_trace(nn_linear(1, 1), torch_randn(1))
#' serialized <- jit_serialize(model)
#' model2 <- jit_unserialize(serialized)
#' @export
jit_unserialize <- function(obj) {
if (!is.raw(obj)) {
value_error("`obj` to be deserialized must be a raw vector.")
}
cpp_jit_script_module_unserialize(obj)
}

ScriptFunction <- R6::R6Class(
"ScriptFunction",
public = list(
Expand Down
21 changes: 21 additions & 0 deletions inst/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -2091,6 +2091,25 @@ HOST_API void lantern_ScriptModule_save (void* self, void* path)

}

LANTERN_API void* (LANTERN_PTR _lantern_ScriptModule_serialize) (void* self);
HOST_API void* lantern_ScriptModule_serialize (void* self)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_ScriptModule_serialize(self);
LANTERN_HOST_HANDLER;
return ret;
}

LANTERN_API void* (LANTERN_PTR _lantern_ScriptModule_unserialize) (void* self);
HOST_API void* lantern_ScriptModule_unserialize (void* self)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_ScriptModule_unserialize(self);
LANTERN_HOST_HANDLER;
return ret;
}


LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_save_for_mobile) (void* self, void* path);
HOST_API void lantern_ScriptModule_save_for_mobile (void* self, void* path)
{
Expand Down Expand Up @@ -10859,6 +10878,8 @@ LOAD_SYMBOL(_lantern_ScriptModule_add_constant);
LOAD_SYMBOL(_lantern_ScriptModule_find_constant);
LOAD_SYMBOL(_lantern_ScriptModule_add_method);
LOAD_SYMBOL(_lantern_ScriptModule_save);
LOAD_SYMBOL(_lantern_ScriptModule_serialize);
LOAD_SYMBOL(_lantern_ScriptModule_unserialize);
LOAD_SYMBOL(_lantern_ScriptModule_save_for_mobile);
LOAD_SYMBOL(_lantern_vector_Scalar_new);
LOAD_SYMBOL(_lantern_vector_Scalar_push_back);
Expand Down
25 changes: 25 additions & 0 deletions man/jit_serialize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/jit_unserialize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47494,6 +47494,28 @@ BEGIN_RCPP
return R_NilValue;
END_RCPP
}
// cpp_jit_script_module_serialize
SEXP cpp_jit_script_module_serialize(XPtrTorchScriptModule self);
RcppExport SEXP _torch_cpp_jit_script_module_serialize(SEXP selfSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< XPtrTorchScriptModule >::type self(selfSEXP);
rcpp_result_gen = Rcpp::wrap(cpp_jit_script_module_serialize(self));
return rcpp_result_gen;
END_RCPP
}
// cpp_jit_script_module_unserialize
SEXP cpp_jit_script_module_unserialize(SEXP input);
RcppExport SEXP _torch_cpp_jit_script_module_unserialize(SEXP inputSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< SEXP >::type input(inputSEXP);
rcpp_result_gen = Rcpp::wrap(cpp_jit_script_module_unserialize(input));
return rcpp_result_gen;
END_RCPP
}
// cpp_jit_script_module_save_for_mobile
void cpp_jit_script_module_save_for_mobile(XPtrTorchScriptModule self, XPtrTorchstring path);
RcppExport SEXP _torch_cpp_jit_script_module_save_for_mobile(SEXP selfSEXP, SEXP pathSEXP) {
Expand Down Expand Up @@ -51942,6 +51964,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_torch_cpp_jit_script_module_add_method", (DL_FUNC) &_torch_cpp_jit_script_module_add_method, 2},
{"_torch_cpp_jit_script_module_find_constant", (DL_FUNC) &_torch_cpp_jit_script_module_find_constant, 2},
{"_torch_cpp_jit_script_module_save", (DL_FUNC) &_torch_cpp_jit_script_module_save, 2},
{"_torch_cpp_jit_script_module_serialize", (DL_FUNC) &_torch_cpp_jit_script_module_serialize, 1},
{"_torch_cpp_jit_script_module_unserialize", (DL_FUNC) &_torch_cpp_jit_script_module_unserialize, 1},
{"_torch_cpp_jit_script_module_save_for_mobile", (DL_FUNC) &_torch_cpp_jit_script_module_save_for_mobile, 2},
{"_torch_test_stack", (DL_FUNC) &_torch_test_stack, 1},
{"_torch_cpp_Tensor_storage", (DL_FUNC) &_torch_cpp_Tensor_storage, 1},
Expand Down
21 changes: 21 additions & 0 deletions src/lantern/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -2091,6 +2091,25 @@ HOST_API void lantern_ScriptModule_save (void* self, void* path)

}

LANTERN_API void* (LANTERN_PTR _lantern_ScriptModule_serialize) (void* self);
HOST_API void* lantern_ScriptModule_serialize (void* self)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_ScriptModule_serialize(self);
LANTERN_HOST_HANDLER;
return ret;
}

LANTERN_API void* (LANTERN_PTR _lantern_ScriptModule_unserialize) (void* self);
HOST_API void* lantern_ScriptModule_unserialize (void* self)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_ScriptModule_unserialize(self);
LANTERN_HOST_HANDLER;
return ret;
}


LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_save_for_mobile) (void* self, void* path);
HOST_API void lantern_ScriptModule_save_for_mobile (void* self, void* path)
{
Expand Down Expand Up @@ -10859,6 +10878,8 @@ LOAD_SYMBOL(_lantern_ScriptModule_add_constant);
LOAD_SYMBOL(_lantern_ScriptModule_find_constant);
LOAD_SYMBOL(_lantern_ScriptModule_add_method);
LOAD_SYMBOL(_lantern_ScriptModule_save);
LOAD_SYMBOL(_lantern_ScriptModule_serialize);
LOAD_SYMBOL(_lantern_ScriptModule_unserialize);
LOAD_SYMBOL(_lantern_ScriptModule_save_for_mobile);
LOAD_SYMBOL(_lantern_vector_Scalar_new);
LOAD_SYMBOL(_lantern_vector_Scalar_push_back);
Expand Down
21 changes: 21 additions & 0 deletions src/lantern/src/ScriptModule.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <torch/csrc/jit/serialization/import.h>
#define LANTERN_BUILD
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/script.h> // One-stop header.
Expand Down Expand Up @@ -191,6 +192,26 @@ void _lantern_ScriptModule_save(void* self, void* path) {
LANTERN_FUNCTION_END_VOID
}

void* _lantern_ScriptModule_serialize(void* self) {
LANTERN_FUNCTION_START
auto self_ = reinterpret_cast<torch::jit::script::Module*>(self);
std::ostringstream oss(std::ios::binary);
self_->save(oss);
auto str = std::string(oss.str());
return make_raw::string(str);
LANTERN_FUNCTION_END
}

void* _lantern_ScriptModule_unserialize(void* s) {
LANTERN_FUNCTION_START
auto str = from_raw::string(s);
std::istringstream input_stream(str);
torch::jit::script::Module module;
module = torch::jit::load(input_stream);
return (void*)new torch::jit::script::Module(module);
LANTERN_FUNCTION_END
}

void* _lantern_ScriptMethod_graph_print(void* self) {
LANTERN_FUNCTION_START
auto self_ = reinterpret_cast<torch::jit::script::Method*>(self);
Expand Down
21 changes: 21 additions & 0 deletions src/script_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,27 @@ void cpp_jit_script_module_save(XPtrTorchScriptModule self,
lantern_ScriptModule_save(self.get(), path.get());
}

// [[Rcpp::export]]
SEXP cpp_jit_script_module_serialize(XPtrTorchScriptModule self) {
torch::string out = lantern_ScriptModule_serialize(self.get());

const char* v = lantern_string_get(out.get());
auto output = std::string(v, lantern_string_size(out.get()));
lantern_const_char_delete(v);

Rcpp::RawVector raw_vec(output.size());
memcpy(&raw_vec[0], output.c_str(), output.size());

return raw_vec;
}

// [[Rcpp::export]]
SEXP cpp_jit_script_module_unserialize(SEXP input) {
auto raw_vec = Rcpp::as<Rcpp::RawVector>(input);
torch::string v = std::string((char*)&raw_vec[0], raw_vec.size());
return XPtrTorchScriptModule(lantern_ScriptModule_unserialize(v.get()));
}

// [[Rcpp::export]]
void cpp_jit_script_module_save_for_mobile(XPtrTorchScriptModule self,
XPtrTorchstring path) {
Expand Down
14 changes: 4 additions & 10 deletions tests/testthat/test-jit-ops.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@ test_that("can access operators via ops object", {
# matmul, default use
res <- jit_ops$aten$matmul(torch::torch_ones(5, 4), torch::torch_rand(4, 5))
expect_equal(dim(res), c(5, 5))

# matmul, passing out tensor
t1 <- torch::torch_ones(4, 4)
t2 <- torch::torch_eye(4)
out <- torch::torch_zeros(4, 4)
jit_ops$aten$matmul(t1, t2, out)
expect_equal_to_tensor(t1, out)

# split, returning two tensors in a list of length 2
res_torch <- torch_split(torch::torch_arange(0, 3), 2, 1)
res_jit <- jit_ops$aten$split(torch::torch_arange(0, 3), torch::jit_scalar(2L), torch::jit_scalar(0L))
expect_length(res_jit, 2)
expect_equal_to_tensor(res_jit[[1]], res_torch[[1]])
expect_equal_to_tensor(res_jit[[2]], res_torch[[2]])

# split, returning a single tensor
res_torch <- torch_split(torch::torch_arange(0, 3), 4, 1)
res_jit <- jit_ops$aten$split(torch::torch_arange(0, 3), torch::jit_scalar(4L), torch::jit_scalar(0L))
expect_length(res_jit, 1)
expect_equal_to_tensor(res_jit[[1]], res_torch[[1]])

# linalg_qr always returns a list
m <- torch_eye(5)/5
res_torch <- linalg_qr(m)
Expand All @@ -37,9 +37,3 @@ test_that("can print ops objects at different levels", {
expect_snapshot(jit_ops$prim$ChunkSizes)
expect_snapshot(jit_ops$aten$fft_fft)
})






11 changes: 11 additions & 0 deletions tests/testthat/test-trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,17 @@ test_that("can save function for mobile", {
expect_equal_to_tensor(torch_relu(input), f(input))
})

test_that("can serialize to raw vector and deserialize again", {
n1 <- jit_trace(nn_linear(1, 1), torch_randn(1, 1))
n1$parameters$bias$requires_grad_(FALSE)
x <- jit_serialize(n1)
expect_true(is.raw(x))
n2 <- jit_unserialize(x)
x <- torch_randn(1)
expect_equal_to_tensor(n1(x), n2(x))
expect_false(n2$parameters$bias$requires_grad)
})

test_that("can define the same method during difference trace-jitting passes (#1246)", {
n <- nn_linear(1, 1)
x <- torch_tensor(1)
Expand Down

0 comments on commit ecdf13b

Please sign in to comment.