Skip to content

Commit

Permalink
Fix(jit): Avoid name clashes between tracings (#1247)
Browse files Browse the repository at this point in the history
Fixes #1246
  • Loading branch information
sebffischer authored Jan 21, 2025
1 parent ba479c7 commit ce82e08
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
These can be used as drop-in replacements for `optim_<name>` but are considerably
faster as they wrap the LibTorch implementation of the optimizer.
The biggest speed differences can be observed for complex optimizers such as `AdamW`.
* Fix: Avoid name clashes between multiple calls to `jit_trace` (#1246)

# torch 0.13.0

Expand Down
6 changes: 4 additions & 2 deletions R/trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ module_ignored_names <- c(
"register_parameter", "register_module", "add_module"
)


# Registry of script-module names handed out so far. Using an environment
# gives reference semantics: every call to make_script_module_name() sees
# and extends the same set of taken names.
module_names <- new.env()

# Produce a unique TorchScript module name derived from the first class of
# `x` (e.g. "nn_linear", then "nn_linear1", "nn_linear2", ...), so repeated
# tracings of the same module class never clash (#1246).
make_script_module_name <- function(x) {
  base <- class(x)[1]
  taken <- names(module_names)
  # make.unique() appends a numeric suffix (sep = "" -> "nn_linear1") only
  # when `base` collides with an already-registered name; the freshly
  # disambiguated entry is the last one, i.e. position length(taken) + 1.
  fresh <- make.unique(c(taken, base), sep = "")[length(taken) + 1L]
  # Record the name in the registry; assigning NULL into an environment
  # creates the binding (unlike lists, it does not remove it).
  module_names[[fresh]] <- NULL
  fresh
}

create_script_module <- function(mod) {
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/script_module.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# can print the graph

graph(%self : __torch__.nn_linear_ydgabwknrsauujvnjgioueiy,
graph(%self : __torch__.nn_linear,
%3 : Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu)):
%bias : Tensor = prim::GetAttr[name="bias"](%self)
%weight : Tensor = prim::GetAttr[name="weight"](%self)
Expand All @@ -9,7 +9,7 @@

---

graph(%self : __torch__.nn_linear_ydgabwknrsauujvnjgioueiy,
graph(%self : __torch__.nn_linear,
%3 : Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu)):
%bias : Tensor = prim::GetAttr[name="bias"](%self)
%weight : Tensor = prim::GetAttr[name="weight"](%self)
Expand All @@ -18,7 +18,7 @@

# graph_for

graph(%self : __torch__.nn_linear_neebjyloatcfjjfottzlywfy,
graph(%self : __torch__.nn_linear1,
%1 : Tensor):
%bias : Tensor = prim::GetAttr[name="bias"](%self)
%weight : Tensor = prim::GetAttr[name="weight"](%self)
Expand All @@ -32,7 +32,7 @@

---

graph(%self : __torch__.nn_linear_neebjyloatcfjjfottzlywfy,
graph(%self : __torch__.nn_linear1,
%1 : Tensor):
%bias : Tensor = prim::GetAttr[name="bias"](%self)
%weight : Tensor = prim::GetAttr[name="weight"](%self)
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,11 @@ test_that("can save function for mobile", {
f <- jit_load(tmp)
expect_equal_to_tensor(torch_relu(input), f(input))
})

test_that("can define the same method during different trace-jitting passes (#1246)", {
  n <- nn_linear(1, 1)
  x <- torch_tensor(1)
  # Tracing the same module twice with identical seeds used to fail with a
  # method-name clash (#1246); the second trace must not error.
  nj1 <- withr::with_seed(1, jit_trace(n, x))
  nj2 <- expect_error(withr::with_seed(1, jit_trace(n, x)), regexp = NA)
  # Both traced modules must reproduce the original module's output.
  # (The previous assertion compared n(x) with itself, which is vacuous.)
  expect_equal_to_tensor(nj1(x), n(x))
  expect_equal_to_tensor(nj2(x), n(x))
})

0 comments on commit ce82e08

Please sign in to comment.